[Mlir-commits] [mlir] [MLIR][Vector] Add unroll pattern for vector.shape_cast (PR #167738)

Nishant Patel llvmlistbot at llvm.org
Tue Nov 18 10:26:10 PST 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/167738

>From cd8b818297287afbed0c675d9bf491bfb296f385 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 11 Nov 2025 00:14:40 +0000
Subject: [PATCH 1/4] Add unroll pattern for vector.shape_cast

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |   1 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |   4 +
 .../Vector/Transforms/VectorUnroll.cpp        | 170 +++++++++++++++++-
 .../Dialect/Vector/vector-unroll-options.mlir |  34 ++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |   6 +
 5 files changed, 213 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43172ff2082df..6ad179349f90f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2427,6 +2427,7 @@ def Vector_CompressStoreOp :
 
 def Vector_ShapeCastOp :
   Vector_Op<"shape_cast", [Pure,
+    DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
   ]>,
     Arguments<(ins AnyVectorOfAnyRank:$source)>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index daef0ba02100a..4cac137478fab 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6241,6 +6241,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
+/// Shape casts are unrolled along the shape of their result vector.
+std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
+  return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
 LogicalResult ShapeCastOp::verify() {
 
   VectorType sourceType = getSourceVectorType();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fbae0989bed26..a4830809aaac8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1003,6 +1003,172 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// Returns true when every target dimension except the outermost one equals
+/// the corresponding trailing dimension of `resultShape`, and the number of
+/// result elements is a multiple of the number of target elements. Returns
+/// false if `targetShape` has a higher rank than `resultShape`.
+static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
+                                ArrayRef<int64_t> resultShape) {
+  if (targetShape.size() > resultShape.size()) {
+    return false;
+  }
+
+  size_t rankDiff = resultShape.size() - targetShape.size();
+  // Inner dimensions must match exactly & total resultElements should be
+  // evenly divisible by targetElements.
+  for (size_t i = 1; i < targetShape.size(); ++i) {
+    if (targetShape[i] != resultShape[rankDiff + i]) {
+      return false;
+    }
+  }
+
+  int64_t targetElements = ShapedType::getNumElements(targetShape);
+  int64_t resultElements = ShapedType::getNumElements(resultShape);
+  if (resultElements % targetElements != 0) {
+    return false;
+  }
+  return true;
+}
+
+// Calculate the shape to extract from source
+// The shape is built innermost-first: each source dimension contributes
+// min(remaining, dim) elements until `targetElements` are covered, and the
+// outer dimensions are then padded with 1. Returns std::nullopt if the counts
+// do not divide evenly or the final shape does not hold `targetElements`.
+// NOTE(review): when only part of a dimension is taken, that dimension must
+// also be a multiple of the part taken; otherwise slices at later offsets run
+// past the end of the dimension (e.g. 8 elements from vector<2x12xf32>).
+static std::optional<SmallVector<int64_t>>
+calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
+                            int64_t targetElements) {
+  SmallVector<int64_t> extractShape;
+  int64_t remainingElements = targetElements;
+
+  // Build extract shape from innermost dimension outward to ensure contiguity
+  for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
+    int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
+    extractShape.insert(extractShape.begin(), takeFromDim);
+
+    if (remainingElements % takeFromDim != 0) {
+      return std::nullopt; // Not evenly divisible
+    }
+    remainingElements /= takeFromDim;
+  }
+
+  // Fill remaining dimensions with 1
+  while (extractShape.size() < sourceShape.size()) {
+    extractShape.insert(extractShape.begin(), 1);
+  }
+
+  if (ShapedType::getNumElements(extractShape) != targetElements) {
+    return std::nullopt;
+  }
+
+  return extractShape;
+}
+
+// Convert result offsets to source offsets via linear position
+// `sourceStrides` and `resultStrides` are expected to be the row-major strides
+// of the source and result shapes (UnrollShapeCastPattern passes
+// computeStrides() of each shape).
+static SmallVector<int64_t>
+calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
+                       ArrayRef<int64_t> sourceStrides,
+                       ArrayRef<int64_t> resultStrides) {
+  // Convert result offsets to linear position
+  int64_t linearIndex = linearize(resultOffsets, resultStrides);
+  // Convert linear position to source offsets
+  SmallVector<int64_t> sourceOffsets = delinearize(linearIndex, sourceStrides);
+  return sourceOffsets;
+}
+
+/// This pattern unrolls `vector.shape_cast` operations according to the
+/// provided target unroll shape. It unrolls a large shape cast into smaller
+/// shape casts by extracting contiguous slices from the source vector, casting
+/// each slice to the target shape, and assembling the result by inserting each
+/// computed segment into the appropriate offset of the result vector.
+///
+/// This pattern only applies when contiguous slices can be extracted from the
+/// source vector and inserted into the result vector such that each slice
+/// remains a vector rather than being decomposed into scalars. In these cases,
+/// the unrolling proceeds as:
+/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
+/// vector.insert_strided_slice
+///
+/// Example:
+///   Given a shape cast operation:
+///     %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
+///
+///   and a target unroll shape of <2x4>, the pattern produces:
+///
+///     %zero = arith.constant dense<0.0> : vector<4x4xf32>
+///     %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
+///       : vector<8x2xf32> to vector<4x2xf32>
+///     %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
+///     %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
+///       : vector<2x4xf32> into vector<4x4xf32>
+///     %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
+///       : vector<8x2xf32> to vector<4x2xf32>
+///     %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
+///     %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
+///       : vector<2x4xf32> into vector<4x4xf32>
+///
+/// The pattern reports a match failure and leaves the op unchanged in two
+/// cases: the target shape does not tile the result vector contiguously, or
+/// no contiguous extraction shape with the same element count exists in the
+/// source vector.
+/// NOTE(review): scalable vector types are not rejected explicitly here;
+/// confirm they cannot reach this pattern.
+///
+struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
+  UnrollShapeCastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, shapeCastOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType sourceType = shapeCastOp.getSourceVectorType();
+    VectorType resultType = shapeCastOp.getResultVectorType();
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+
+    if (!isContiguousExtract(*targetShape, resultShape)) {
+      return rewriter.notifyMatchFailure(shapeCastOp,
+                                         "Only supports cases where contiguous "
+                                         "extraction is possible");
+    }
+
+    int64_t targetElements = ShapedType::getNumElements(*targetShape);
+
+    // Calculate the shape to extract from source
+    auto extractShape =
+        calculateSourceExtractShape(sourceShape, targetElements);
+    if (!extractShape) {
+      return rewriter.notifyMatchFailure(
+          shapeCastOp,
+          "cannot extract target number of elements contiguously from source");
+    }
+
+    Location loc = shapeCastOp.getLoc();
+
+    // Create result vector initialized to zero
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+
+    VectorType targetType =
+        VectorType::get(*targetShape, sourceType.getElementType());
+
+    SmallVector<int64_t> extractStrides(extractShape->size(), 1);
+    SmallVector<int64_t> insertStrides(targetShape->size(), 1);
+    SmallVector<int64_t> sourceStrides = computeStrides(sourceShape);
+    SmallVector<int64_t> resultStrides = computeStrides(resultShape);
+
+    for (SmallVector<int64_t> resultOffsets :
+         StaticTileOffsetRange(resultShape, *targetShape)) {
+      SmallVector<int64_t> sourceOffsets =
+          calculateSourceOffsets(resultOffsets, sourceStrides, resultStrides);
+      Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
+          extractStrides);
+      Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
+          loc, targetType, sourceChunk);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, targetChunk, result, resultOffsets, insertStrides);
+    }
+
+    rewriter.replaceOp(shapeCastOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace
 
 void mlir::vector::populateVectorUnrollPatterns(
@@ -1013,8 +1179,8 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollReductionPattern, UnrollMultiReductionPattern,
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
                UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
-               UnrollToElements, UnrollStepPattern>(patterns.getContext(),
-                                                    options, benefit);
+               UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
+      patterns.getContext(), options, benefit);
 }
 
 void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e5a98b5c67f33..c94a502fa3654 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -496,3 +496,37 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
 // CHECK-COUNT-4:   arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
 // CHECK-NOT: arith.addf
 // CHECK: return
+
+
+// The 16 source elements are split into two contiguous 8-element slices; each
+// slice is cast to the native 2x4 shape and inserted into the 2x2x4 result.
+func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
+  %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
+  return %0 : vector<2x2x4xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_1D
+// CHECK-SAME: (%[[ARG0:.*]]: vector<16xf32>) -> vector<2x2x4xf32> {
+// CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32>
+// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK:   %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32>
+// CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
+// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK:   %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32>
+// CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
+// CHECK:   return %[[I1]] : vector<2x2x4xf32>
+
+
+// Each 2x4 result tile is filled from a contiguous 4x2 slice of the 8x2 source.
+func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> {
+  %0 = vector.shape_cast %v : vector<8x2xf32> to vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_2D
+// CHECK-SAME: (%[[ARG0:.*]]: vector<8x2xf32>) -> vector<4x4xf32> {
+// CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
+// CHECK:   %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32>
+// CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
+// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
+// CHECK:   %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32>
+// CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
+// CHECK:   return %[[I1]] : vector<4x4xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 79bfc9bbcda71..0ab4e451d544d 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -178,6 +178,12 @@ struct TestVectorUnrollingPatterns
                                      .setFilterConstraint([](Operation *op) {
                                        return success(isa<vector::StepOp>(op));
                                      }));
+    populateVectorUnrollPatterns(
+        patterns, UnrollVectorOptions()
+                      .setNativeShape(ArrayRef<int64_t>{2, 4})
+                      .setFilterConstraint([](Operation *op) {
+                        return success(isa<vector::ShapeCastOp>(op));
+                      }));
     populateVectorUnrollPatterns(
         patterns, UnrollVectorOptions()
                       .setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})

>From 73512fd722ea836ea96ec31d55f55e893c6f9b14 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 12 Nov 2025 19:38:26 +0000
Subject: [PATCH 2/4] Address feedback

---
 .../Vector/Transforms/VectorUnroll.cpp        | 59 ++++++++-----------
 1 file changed, 24 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index a4830809aaac8..7afc83bb8a876 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1005,67 +1005,57 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
 
+/// Returns true when every target dimension except the outermost one equals
+/// the corresponding trailing dimension of `resultShape`, and the number of
+/// result elements is a multiple of the number of target elements.
+/// NOTE(review): targetShape.drop_front() requires a non-empty targetShape.
 static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
                                 ArrayRef<int64_t> resultShape) {
-  if (targetShape.size() > resultShape.size()) {
+  if (targetShape.size() > resultShape.size())
     return false;
-  }
 
   size_t rankDiff = resultShape.size() - targetShape.size();
   // Inner dimensions must match exactly & total resultElements should be
   // evenly divisible by targetElements.
-  for (size_t i = 1; i < targetShape.size(); ++i) {
-    if (targetShape[i] != resultShape[rankDiff + i]) {
-      return false;
-    }
-  }
+  if (!llvm::equal(targetShape.drop_front(),
+                   resultShape.drop_front(rankDiff + 1)))
+    return false;
 
   int64_t targetElements = ShapedType::getNumElements(targetShape);
   int64_t resultElements = ShapedType::getNumElements(resultShape);
-  if (resultElements % targetElements != 0) {
-    return false;
-  }
-  return true;
+  return resultElements % targetElements == 0;
 }
 
-// Calculate the shape to extract from source
+// Calculate the shape to extract from source.
 static std::optional<SmallVector<int64_t>>
 calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
                             int64_t targetElements) {
   SmallVector<int64_t> extractShape;
   int64_t remainingElements = targetElements;
 
-  // Build extract shape from innermost dimension outward to ensure contiguity
+  // Build extract shape from innermost dimension outward to ensure contiguity.
   for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
     int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
     extractShape.insert(extractShape.begin(), takeFromDim);
 
-    if (remainingElements % takeFromDim != 0) {
-      return std::nullopt; // Not evenly divisible
-    }
+    // The elements still needed must be a multiple of what this dimension
+    // provides. When only part of the dimension is taken, the dimension size
+    // must also be a multiple of that part. The caller extracts slices at
+    // linear offsets that are multiples of targetElements, so otherwise a
+    // later slice would run past the end of the dimension (e.g. 8 elements
+    // from vector<2x12xf32> would need a [1, 8] slice at offset [0, 8]).
+    if (remainingElements % takeFromDim != 0 ||
+        sourceShape[i] % takeFromDim != 0)
+      return std::nullopt; // Not evenly divisible.
     remainingElements /= takeFromDim;
   }
 
-  // Fill remaining dimensions with 1
-  while (extractShape.size() < sourceShape.size()) {
+  // Fill remaining dimensions with 1.
+  while (extractShape.size() < sourceShape.size())
     extractShape.insert(extractShape.begin(), 1);
-  }
 
-  if (ShapedType::getNumElements(extractShape) != targetElements) {
+  if (ShapedType::getNumElements(extractShape) != targetElements)
     return std::nullopt;
-  }
 
   return extractShape;
 }
 
-// Convert result offsets to source offsets via linear position
+// Convert result offsets to source offsets via linear position.
+// `sourceStrides` and `resultStrides` are expected to be the row-major strides
+// of the source and result shapes (as produced by computeStrides()).
 static SmallVector<int64_t>
 calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
                        ArrayRef<int64_t> sourceStrides,
                        ArrayRef<int64_t> resultStrides) {
-  // Convert result offsets to linear position
+  // Convert result offsets to linear position.
   int64_t linearIndex = linearize(resultOffsets, resultStrides);
-  // Convert linear position to source offsets
-  SmallVector<int64_t> sourceOffsets = delinearize(linearIndex, sourceStrides);
-  return sourceOffsets;
+  // Convert linear position to source offsets.
+  return delinearize(linearIndex, sourceStrides);
 }
 
 /// This pattern unrolls `vector.shape_cast` operations according to the
@@ -1079,7 +1069,7 @@ calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
 /// remains a valid vector (and not decompose to scalars). In these cases, the
 /// unrolling proceeds as:
 /// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
-/// vector.insert_strided_slice
+/// vector.insert_strided_slice.
 ///
 /// Example:
 ///   Given a shape cast operation:
@@ -1108,7 +1098,8 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
 
   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                 PatternRewriter &rewriter) const override {
-    auto targetShape = getTargetShape(options, shapeCastOp);
+    std::optional<SmallVector<int64_t>> targetShape =
+        getTargetShape(options, shapeCastOp);
     if (!targetShape)
       return failure();
 
@@ -1117,26 +1108,24 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
     ArrayRef<int64_t> sourceShape = sourceType.getShape();
     ArrayRef<int64_t> resultShape = resultType.getShape();
 
-    if (!isContiguousExtract(*targetShape, resultShape)) {
+    if (!isContiguousExtract(*targetShape, resultShape))
       return rewriter.notifyMatchFailure(shapeCastOp,
                                          "Only supports cases where contiguous "
                                          "extraction is possible");
-    }
 
     int64_t targetElements = ShapedType::getNumElements(*targetShape);
 
-    // Calculate the shape to extract from source
-    auto extractShape =
+    // Calculate the shape to extract from source.
+    std::optional<SmallVector<int64_t>> extractShape =
         calculateSourceExtractShape(sourceShape, targetElements);
-    if (!extractShape) {
+    if (!extractShape)
       return rewriter.notifyMatchFailure(
           shapeCastOp,
           "cannot extract target number of elements contiguously from source");
-    }
 
     Location loc = shapeCastOp.getLoc();
 
-    // Create result vector initialized to zero
+    // Create result vector initialized to zero.
     Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                              rewriter.getZeroAttr(resultType));
 

>From 9b4191a1c63c033fbf8f88dc9b227c1db1a936db Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 13 Nov 2025 00:40:23 +0000
Subject: [PATCH 3/4] Fix isContiguousExtract

---
 .../Vector/Transforms/VectorUnroll.cpp        | 50 ++++++++++++++++---
 1 file changed, 42 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 7afc83bb8a876..885fcf835c1a3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1008,16 +1008,50 @@ static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
   if (targetShape.size() > resultShape.size())
     return false;
 
-  size_t rankDiff = resultShape.size() - targetShape.size();
-  // Inner dimensions must match exactly & total resultElements should be
-  // evenly divisible by targetElements.
-  if (!llvm::equal(targetShape.drop_front(),
-                   resultShape.drop_front(rankDiff + 1)))
-    return false;
-
   int64_t targetElements = ShapedType::getNumElements(targetShape);
   int64_t resultElements = ShapedType::getNumElements(resultShape);
-  return resultElements % targetElements == 0;
+
+  // Result must be evenly divisible by target.
+  if (resultElements % targetElements != 0)
+    return false;
+
+  // For contiguous extraction, we need to be able to
+  // extract targetElements contiguously from the result shape.
+  // This means we can "consume" dimensions from the innermost outward
+  // until we have exactly targetElements.
+
+  int64_t remainingElements = targetElements;
+  int targetDimIdx = targetShape.size() - 1;
+
+  // Work backwards through result dimensions.
+  for (int resultDimIdx = resultShape.size() - 1;
+       resultDimIdx >= 0 && remainingElements > 1 && targetDimIdx >= 0;
+       --resultDimIdx) {
+
+    int64_t resultDimSize = resultShape[resultDimIdx];
+    int64_t targetDimSize = targetShape[targetDimIdx];
+
+    if (targetDimSize > resultDimSize)
+      return false;
+
+    if (targetDimSize == resultDimSize) {
+      if (remainingElements % targetDimSize != 0)
+        return false;
+      remainingElements /= targetDimSize;
+      --targetDimIdx;
+    } else {
+      if (remainingElements != targetDimSize)
+        return false;
+      remainingElements = 1;
+      --targetDimIdx;
+    }
+  }
+
+  // Check remaining target dimensions are all 1 and we consumed all elements
+  return remainingElements == 1 &&
+         (targetDimIdx < 0 || llvm::all_of(
+                                  targetShape.take_front(targetDimIdx + 1),
+                                  [](int64_t d) { return d == 1; }));
 }
 
 // Calculate the shape to extract from source.

>From d4ea820d64c74a829225de31715be50a96045fa7 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 18 Nov 2025 16:52:35 +0000
Subject: [PATCH 4/4] Address feedback

---
 .../Vector/Transforms/VectorUnroll.cpp        | 110 +++++++++---------
 .../Dialect/Vector/vector-unroll-options.mlir |  39 ++++++-
 2 files changed, 88 insertions(+), 61 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 885fcf835c1a3..0a1d86109beea 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1003,58 +1003,60 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
   vector::UnrollVectorOptions options;
 };
 
-static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
-                                ArrayRef<int64_t> resultShape) {
-  if (targetShape.size() > resultShape.size())
-    return false;
-
-  int64_t targetElements = ShapedType::getNumElements(targetShape);
-  int64_t resultElements = ShapedType::getNumElements(resultShape);
+/// Checks whether each `targetShape`-sized tile of a vector of shape
+/// `resultShape` covers a contiguous (row-major) range of elements.
+/// targetShape must not have a higher rank than resultShape. Leading unit
+/// dimensions of both shapes are then ignored, and targetShape is contiguous
+/// in resultShape when:
+/// 1) all dimensions of targetShape except the outermost one equal the
+///    trailing dimensions of resultShape, and
+/// 2) the total number of elements in resultShape is evenly divisible by
+///    the total number of elements in targetShape.
+/// A targetShape made only of unit dimensions (a single element) is always
+/// contiguous.
+/// Examples:
+///   isContiguous([4, 4], [8, 4]) == true
+///   isContiguous([2, 4], [8, 4]) == true
+///   isContiguous([2, 2], [8, 4]) == false
+///   isContiguous([1, 16], [1, 32]) == true
+///   isContiguous([1, 1], [4, 4]) == true
+static bool isContiguous(ArrayRef<int64_t> targetShape,
+                         ArrayRef<int64_t> resultShape) {
-
-  // Result must be evenly divisible by target.
-  if (resultElements % targetElements != 0)
+  if (targetShape.size() > resultShape.size())
     return false;
 
-  // For contiguous extraction, we need to be able to
-  // extract targetElements contiguously from the result shape.
-  // This means we can "consume" dimensions from the innermost outward
-  // until we have exactly targetElements.
+  // Leading unit dimensions do not affect contiguity.
+  auto isUnitDim = [](int64_t dim) { return dim == 1; };
+  targetShape = targetShape.drop_while(isUnitDim);
+  resultShape = resultShape.drop_while(isUnitDim);
 
-  int64_t remainingElements = targetElements;
-  int targetDimIdx = targetShape.size() - 1;
-
-  // Work backwards through result dimensions.
-  for (int resultDimIdx = resultShape.size() - 1;
-       resultDimIdx >= 0 && remainingElements > 1 && targetDimIdx >= 0;
-       --resultDimIdx) {
-
-    int64_t resultDimSize = resultShape[resultDimIdx];
-    int64_t targetDimSize = targetShape[targetDimIdx];
-
-    if (targetDimSize > resultDimSize)
-      return false;
-
-    if (targetDimSize == resultDimSize) {
-      if (remainingElements % targetDimSize != 0)
-        return false;
-      remainingElements /= targetDimSize;
-      --targetDimIdx;
-    } else {
-      if (remainingElements != targetDimSize)
-        return false;
-      remainingElements = 1;
-      --targetDimIdx;
-    }
-  }
-
-  // Check remaining target dimensions are all 1 and we consumed all elements
-  return remainingElements == 1 &&
-         (targetDimIdx < 0 || llvm::all_of(
-                                  targetShape.take_front(targetDimIdx + 1),
-                                  [](int64_t d) { return d == 1; }));
+  // Only unit dimensions were left: each tile is a single element, which is
+  // trivially contiguous. Returning here also keeps drop_front() below from
+  // being applied to an empty shape.
+  if (targetShape.empty())
+    return true;
+  // Guard the unsigned subtraction below.
+  if (targetShape.size() > resultShape.size())
+    return false;
+
+  // All target dimensions except the outermost must match the trailing
+  // result dimensions exactly.
+  size_t rankDiff = resultShape.size() - targetShape.size();
+  if (!llvm::equal(targetShape.drop_front(),
+                   resultShape.drop_front(rankDiff + 1)))
+    return false;
+
+  int64_t targetElements = ShapedType::getNumElements(targetShape);
+  int64_t resultElements = ShapedType::getNumElements(resultShape);
+  return resultElements % targetElements == 0;
 }
 
-// Calculate the shape to extract from source.
+/// This function determines what shape to use with
+/// `vector.extract_strided_slice` to extract a contiguous memory region from a
+/// source vector. The extraction must be contiguous and contain exactly the
+/// specified number of elements. If such an extraction shape cannot be
+/// determined, the function returns std::nullopt.
+/// Examples:
+///   sourceShape = [16], targetElements = 8
+///   Working right-to-left:
+///   - Take min(8, 16) = 8 from only dim → extractShape = [8],
+///     remaining = 8/8 = 1
+///     Result: [8]
+///
+///   sourceShape = [4, 4], targetElements = 8
+///   Working right-to-left:
+///   - Take min(8, 4) = 4 from last dim → extractShape = [4],
+///     remaining = 8/4 = 2
+///   - Take min(2, 4) = 2 from first dim → extractShape = [2, 4],
+///     remaining = 2/2 = 1
+///     Result: [2, 4]
 static std::optional<SmallVector<int64_t>>
 calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
                             int64_t targetElements) {
@@ -1084,12 +1086,12 @@ calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
 // Convert result offsets to source offsets via linear position.
+// `resultOffsets` index into a vector of shape `resultShape`; the returned
+// offsets address the element at the same row-major linear position in a
+// vector of shape `sourceShape`.
+// NOTE(review): strides for both shapes are recomputed on every call, and the
+// caller invokes this once per tile.
 static SmallVector<int64_t>
 calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
-                       ArrayRef<int64_t> sourceStrides,
-                       ArrayRef<int64_t> resultStrides) {
+                       ArrayRef<int64_t> sourceShape,
+                       ArrayRef<int64_t> resultShape) {
   // Convert result offsets to linear position.
-  int64_t linearIndex = linearize(resultOffsets, resultStrides);
+  int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape));
   // Convert linear position to source offsets.
-  return delinearize(linearIndex, sourceStrides);
+  return delinearize(linearIndex, computeStrides(sourceShape));
 }
 
 /// This pattern unrolls `vector.shape_cast` operations according to the
@@ -1142,10 +1144,10 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
     ArrayRef<int64_t> sourceShape = sourceType.getShape();
     ArrayRef<int64_t> resultShape = resultType.getShape();
 
-    if (!isContiguousExtract(*targetShape, resultShape))
-      return rewriter.notifyMatchFailure(shapeCastOp,
-                                         "Only supports cases where contiguous "
-                                         "extraction is possible");
+    if (!isContiguous(*targetShape, resultShape))
+      return rewriter.notifyMatchFailure(
+          shapeCastOp, "Only supports cases where target shape is "
+                       "contiguous in result vector shape");
 
     int64_t targetElements = ShapedType::getNumElements(*targetShape);
 
@@ -1168,13 +1170,11 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
 
     SmallVector<int64_t> extractStrides(extractShape->size(), 1);
     SmallVector<int64_t> insertStrides(targetShape->size(), 1);
-    SmallVector<int64_t> sourceStrides = computeStrides(sourceShape);
-    SmallVector<int64_t> resultStrides = computeStrides(resultShape);
 
     for (SmallVector<int64_t> resultOffsets :
          StaticTileOffsetRange(resultShape, *targetShape)) {
       SmallVector<int64_t> sourceOffsets =
-          calculateSourceOffsets(resultOffsets, sourceStrides, resultStrides);
+          calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
       Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
           loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
           extractStrides);
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index c94a502fa3654..8e2caa39696cb 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -504,12 +504,12 @@ func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
 }
 
 // CHECK-LABEL: func @shape_cast_1D
-// CHECK-SAME: (%[[ARG0:.*]]: vector<16xf32>) -> vector<2x2x4xf32> {
+// CHECK-SAME: (%[[V:.*]]: vector<16xf32>) -> vector<2x2x4xf32> {
 // CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32>
-// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
 // CHECK:   %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32>
 // CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
-// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
 // CHECK:   %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32>
 // CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
 // CHECK:   return %[[I1]] : vector<2x2x4xf32>
@@ -521,12 +521,39 @@ func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> {
 }
 
 // CHECK-LABEL: func @shape_cast_2D
-// CHECK-SAME: (%[[ARG0:.*]]: vector<8x2xf32>) -> vector<4x4xf32> {
+// CHECK-SAME: (%[[V:.*]]: vector<8x2xf32>) -> vector<4x4xf32> {
 // CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
-// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
+// CHECK:   %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
 // CHECK:   %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32>
 // CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
-// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
+// CHECK:   %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
 // CHECK:   %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32>
 // CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
 // CHECK:   return %[[I1]] : vector<4x4xf32>
+
+
+// Negative test: the target shape (2x4) is not contiguous in the 8x8 result
+// vector, so the shape cast must not be unrolled.
+func.func @negative_shape_cast_target_shape_not_contiguous(%v: vector<64xf32>) -> vector<8x8xf32> {
+  %0 = vector.shape_cast %v : vector<64xf32> to vector<8x8xf32>
+  return %0 : vector<8x8xf32>
+}
+
+// CHECK-LABEL: func @negative_shape_cast_target_shape_not_contiguous
+// CHECK-SAME: (%[[V:.*]]: vector<64xf32>) -> vector<8x8xf32> {
+// CHECK:   %[[SC:.*]] = vector.shape_cast %[[V]] : vector<64xf32> to vector<8x8xf32>
+// CHECK:   return %[[SC]] : vector<8x8xf32>
+
+
+// Negative test: no contiguous extraction shape for the 8-element target tile
+// (2x4) exists in the 8x3 source vector, so the shape cast must not be
+// unrolled.
+func.func @negative_shape_cast_source_shape_not_determinable(%v: vector<8x3xf32>) -> vector<6x4xf32> {
+  %0 = vector.shape_cast %v : vector<8x3xf32> to vector<6x4xf32>
+  return %0 : vector<6x4xf32>
+}
+
+// CHECK-LABEL: func @negative_shape_cast_source_shape_not_determinable
+// CHECK-SAME: (%[[V:.*]]: vector<8x3xf32>) -> vector<6x4xf32> {
+// CHECK:   %[[SC:.*]] = vector.shape_cast %[[V]] : vector<8x3xf32> to vector<6x4xf32>
+// CHECK:   return %[[SC]] : vector<6x4xf32>



More information about the Mlir-commits mailing list