[Mlir-commits] [mlir] [MLIR][Vector] Extend elementwise pattern to support unrolling from higher rank to lower rank (PR #162515)

Nishant Patel llvmlistbot at llvm.org
Tue Oct 14 09:07:29 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/162515

>From f854b2dcf748a0ff9e5e39d5fae70e3cd437f774 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 8 Oct 2025 16:37:37 +0000
Subject: [PATCH 1/8] Extend elementwise to support unrolling from higher rank
 to lower rank

---
 .../Vector/Transforms/VectorUnroll.cpp        | 52 +++++++++++-----
 .../Dialect/Vector/vector-unroll-options.mlir | 60 +++++++++++++++++--
 2 files changed, 92 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 14639c5f1cdd3..62d65e28e8c2e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -468,23 +468,30 @@ struct UnrollElementwisePattern : public RewritePattern {
     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
     SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
-    // Bail-out if rank(source) != rank(target). The main limitation here is the
-    // fact that `ExtractStridedSlice` requires the rank for the input and
-    // output to match. If needed, we can relax this later.
-    if (originalSize.size() != targetShape->size())
-      return rewriter.notifyMatchFailure(
-          op, "expected input vector rank to match target shape rank");
+
     Location loc = op->getLoc();
+
+    // Handle rank mismatch by adding leading unit dimensions to targetShape
+    SmallVector<int64_t> adjustedTargetShape = *targetShape;
+    SmallVector<int64_t> adjustedOffsets;
+    if (originalSize.size() > targetShape->size()) {
+      // Add leading unit dimensions to targetShape
+      int64_t rankDiff = originalSize.size() - targetShape->size();
+      adjustedTargetShape.insert(adjustedTargetShape.begin(), rankDiff, 1);
+    }
+
     // Prepare the result vector.
     Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                              rewriter.getZeroAttr(dstVecType));
-    SmallVector<int64_t> strides(targetShape->size(), 1);
-    VectorType newVecType =
+    SmallVector<int64_t> strides(adjustedTargetShape.size(), 1);
+    VectorType extractVecType =
+        VectorType::get(adjustedTargetShape, dstVecType.getElementType());
+    VectorType computeVecType =
         VectorType::get(*targetShape, dstVecType.getElementType());
 
     // Create the unrolled computation.
     for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(originalSize, *targetShape)) {
+         StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
       SmallVector<Value> extractOperands;
       for (OpOperand &operand : op->getOpOperands()) {
         auto vecType = dyn_cast<VectorType>(operand.get().getType());
@@ -492,14 +499,31 @@ struct UnrollElementwisePattern : public RewritePattern {
           extractOperands.push_back(operand.get());
           continue;
         }
-        extractOperands.push_back(
-            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
-                loc, operand.get(), offsets, *targetShape, strides));
+        Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, operand.get(), offsets, adjustedTargetShape, strides);
+
+        // Reshape to remove leading unit dims if needed
+        if (adjustedTargetShape.size() > targetShape->size()) {
+          extracted = rewriter.createOrFold<vector::ShapeCastOp>(
+              loc, VectorType::get(*targetShape, vecType.getElementType()),
+              extracted);
+        }
+        extractOperands.push_back(extracted);
       }
+
       Operation *newOp = cloneOpWithOperandsAndTypes(
-          rewriter, loc, op, extractOperands, newVecType);
+          rewriter, loc, op, extractOperands, computeVecType);
+
+      Value computeResult = newOp->getResult(0);
+
+      // Reshape back to higher rank if needed for insertion
+      if (adjustedTargetShape.size() > targetShape->size()) {
+        computeResult = rewriter.createOrFold<vector::ShapeCastOp>(
+            loc, extractVecType, computeResult);
+      }
+
       result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
-          loc, newOp->getResult(0), result, offsets, strides);
+          loc, computeResult, result, offsets, strides);
     }
     rewriter.replaceOp(op, result);
     return success();
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 35db14e0f7f1d..a26e4b0baa05b 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,15 +188,40 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
 //   CHECK-LABEL: func @vector_fma
 // CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
 
-// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
-func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
+func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
   %0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
   return %0 : vector<3x2x2xf32>
 }
-// CHECK-LABEL: func @negative_vector_fma_3d
-//   CHECK-NOT: vector.extract_strided_slice
-//       CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
-//       CHECK: return
+// CHECK-LABEL: func @vector_fma_3d
+//       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32>
+//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA0:.*]] = vector.fma %[[S0]], %[[S1]], %[[S2]] : vector<2x2xf32>
+//       CHECK:   %[[SC0:.*]] = vector.shape_cast %[[FMA0]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S4:.*]] = vector.shape_cast %[[E4]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S5:.*]] = vector.shape_cast %[[E5]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA1:.*]] = vector.fma %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>
+//       CHECK:   %[[SC1:.*]] = vector.shape_cast %[[FMA1]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S6:.*]] = vector.shape_cast %[[E6]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S7:.*]] = vector.shape_cast %[[E7]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S8:.*]] = vector.shape_cast %[[E8]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA2:.*]] = vector.fma %[[S6]], %[[S7]], %[[S8]] : vector<2x2xf32>
+//       CHECK:   %[[SC2:.*]] = vector.shape_cast %[[FMA2]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[SC2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   return %[[I2]] : vector<3x2x2xf32>
 
 func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
@@ -440,3 +465,26 @@ func.func @vector_step() -> vector<32xindex> {
 // CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
 // CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
 // CHECK: return %[[INS3]] : vector<32xindex>
+
+
+func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+  %0 = arith.addf %v1, %v2 : vector<2x2x2xf32>
+  return %0 : vector<2x2x2xf32>
+}
+// CHECK-LABEL: func @elementwise
+//       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32>
+//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[ADD0:.*]] = arith.addf %[[S0]], %[[S1]] : vector<2x2xf32>
+//       CHECK:   %[[SC0:.*]] = vector.shape_cast %[[ADD0]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[ADD1:.*]] = arith.addf %[[S2]], %[[S3]] : vector<2x2xf32>
+//       CHECK:   %[[SC1:.*]] = vector.shape_cast %[[ADD1]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+//       CHECK:   return %[[I1]] : vector<2x2x2xf32>

>From 01357d9b7dc5f0707ab8220bb7eb561935f88d35 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 9 Oct 2025 17:26:21 +0000
Subject: [PATCH 2/8] Remove shape_cast before insert_strided_slice

---
 .../Dialect/Vector/Transforms/VectorUnroll.cpp    | 14 ++++++--------
 .../Dialect/Vector/vector-unroll-options.mlir     | 15 +++++----------
 2 files changed, 11 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 62d65e28e8c2e..dd01873ded05e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -484,8 +484,6 @@ struct UnrollElementwisePattern : public RewritePattern {
     Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                              rewriter.getZeroAttr(dstVecType));
     SmallVector<int64_t> strides(adjustedTargetShape.size(), 1);
-    VectorType extractVecType =
-        VectorType::get(adjustedTargetShape, dstVecType.getElementType());
     VectorType computeVecType =
         VectorType::get(*targetShape, dstVecType.getElementType());
 
@@ -516,14 +514,14 @@ struct UnrollElementwisePattern : public RewritePattern {
 
       Value computeResult = newOp->getResult(0);
 
-      // Reshape back to higher rank if needed for insertion
-      if (adjustedTargetShape.size() > targetShape->size()) {
-        computeResult = rewriter.createOrFold<vector::ShapeCastOp>(
-            loc, extractVecType, computeResult);
-      }
+      // Use strides sized to targetShape for proper insertion
+      SmallVector<int64_t> insertStrides =
+          (adjustedTargetShape.size() > targetShape->size())
+              ? SmallVector<int64_t>(targetShape->size(), 1)
+              : strides;
 
       result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
-          loc, computeResult, result, offsets, strides);
+          loc, computeResult, result, offsets, insertStrides);
     }
     rewriter.replaceOp(op, result);
     return success();
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index a26e4b0baa05b..feea7c0f74531 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -201,8 +201,7 @@ func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
 //       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[FMA0:.*]] = vector.fma %[[S0]], %[[S1]], %[[S2]] : vector<2x2xf32>
-//       CHECK:   %[[SC0:.*]] = vector.shape_cast %[[FMA0]] : vector<2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[FMA0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
 //       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
@@ -210,8 +209,7 @@ func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
 //       CHECK:   %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S5:.*]] = vector.shape_cast %[[E5]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[FMA1:.*]] = vector.fma %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>
-//       CHECK:   %[[SC1:.*]] = vector.shape_cast %[[FMA1]] : vector<2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[FMA1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
 //       CHECK:   %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S6:.*]] = vector.shape_cast %[[E6]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
@@ -219,8 +217,7 @@ func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
 //       CHECK:   %[[E8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S8:.*]] = vector.shape_cast %[[E8]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[FMA2:.*]] = vector.fma %[[S6]], %[[S7]], %[[S8]] : vector<2x2xf32>
-//       CHECK:   %[[SC2:.*]] = vector.shape_cast %[[FMA2]] : vector<2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[SC2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[FMA2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
 //       CHECK:   return %[[I2]] : vector<3x2x2xf32>
 
 func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
@@ -478,13 +475,11 @@ func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector
 //       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[ADD0:.*]] = arith.addf %[[S0]], %[[S1]] : vector<2x2xf32>
-//       CHECK:   %[[SC0:.*]] = vector.shape_cast %[[ADD0]] : vector<2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[ADD0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
 //       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[ADD1:.*]] = arith.addf %[[S2]], %[[S3]] : vector<2x2xf32>
-//       CHECK:   %[[SC1:.*]] = vector.shape_cast %[[ADD1]] : vector<2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
 //       CHECK:   return %[[I1]] : vector<2x2x2xf32>

>From d59390bd29204430d5d64714a4f7792dbc1e8718 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 9 Oct 2025 17:43:01 +0000
Subject: [PATCH 3/8] Remove unused variable adjustedOffsets

---
 mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index dd01873ded05e..613819232c69b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -473,7 +473,6 @@ struct UnrollElementwisePattern : public RewritePattern {
 
     // Handle rank mismatch by adding leading unit dimensions to targetShape
     SmallVector<int64_t> adjustedTargetShape = *targetShape;
-    SmallVector<int64_t> adjustedOffsets;
     if (originalSize.size() > targetShape->size()) {
       // Add leading unit dimensions to targetShape
       int64_t rankDiff = originalSize.size() - targetShape->size();

>From 7f567ecadfa16e53ee7da107f19cfa9a3b244de8 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 13 Oct 2025 21:16:21 +0000
Subject: [PATCH 4/8] Address feedback

---
 .../Vector/Transforms/VectorUnroll.cpp        | 25 ++++++-----
 .../Dialect/Vector/vector-unroll-options.mlir | 43 ++++++++++---------
 2 files changed, 36 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 613819232c69b..3b0df2bab9717 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -465,24 +465,27 @@ struct UnrollElementwisePattern : public RewritePattern {
     auto targetShape = getTargetShape(options, op);
     if (!targetShape)
       return failure();
+    int64_t targetShapeRank = targetShape->size();
     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
     SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+    int64_t originalShapeRank = originalSize.size();
 
     Location loc = op->getLoc();
 
     // Handle rank mismatch by adding leading unit dimensions to targetShape
-    SmallVector<int64_t> adjustedTargetShape = *targetShape;
-    if (originalSize.size() > targetShape->size()) {
-      // Add leading unit dimensions to targetShape
-      int64_t rankDiff = originalSize.size() - targetShape->size();
-      adjustedTargetShape.insert(adjustedTargetShape.begin(), rankDiff, 1);
-    }
-
+    SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
+    int64_t rankDiff = originalShapeRank - targetShapeRank;
+    std::fill(adjustedTargetShape.begin(),
+              adjustedTargetShape.begin() + rankDiff, 1);
+    std::copy(targetShape->begin(), targetShape->end(),
+              adjustedTargetShape.begin() + rankDiff);
+
+    int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
     // Prepare the result vector.
     Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                              rewriter.getZeroAttr(dstVecType));
-    SmallVector<int64_t> strides(adjustedTargetShape.size(), 1);
+    SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
     VectorType computeVecType =
         VectorType::get(*targetShape, dstVecType.getElementType());
 
@@ -500,7 +503,7 @@ struct UnrollElementwisePattern : public RewritePattern {
             loc, operand.get(), offsets, adjustedTargetShape, strides);
 
         // Reshape to remove leading unit dims if needed
-        if (adjustedTargetShape.size() > targetShape->size()) {
+        if (adjustedTargetShapeRank > targetShapeRank) {
           extracted = rewriter.createOrFold<vector::ShapeCastOp>(
               loc, VectorType::get(*targetShape, vecType.getElementType()),
               extracted);
@@ -515,8 +518,8 @@ struct UnrollElementwisePattern : public RewritePattern {
 
       // Use strides sized to targetShape for proper insertion
       SmallVector<int64_t> insertStrides =
-          (adjustedTargetShape.size() > targetShape->size())
-              ? SmallVector<int64_t>(targetShape->size(), 1)
+          (adjustedTargetShapeRank > targetShapeRank)
+              ? SmallVector<int64_t>(targetShapeRank, 1)
               : strides;
 
       result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index feea7c0f74531..ca06c037c4b11 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -193,30 +193,31 @@ func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
   return %0 : vector<3x2x2xf32>
 }
 // CHECK-LABEL: func @vector_fma_3d
+//  CHECK-SAME:   (%[[SRC:.*]]: vector<3x2x2xf32>) -> vector<3x2x2xf32> {
 //       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32>
-//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
-//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
-//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
-//       CHECK:   %[[FMA0:.*]] = vector.fma %[[S0]], %[[S1]], %[[S2]] : vector<2x2xf32>
+//       CHECK:   %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_OUT_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_OUT_0:.*]] = vector.shape_cast %[[E_OUT_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA0:.*]] = vector.fma %[[S_LHS_0]], %[[S_RHS_0]], %[[S_OUT_0]] : vector<2x2xf32>
 //       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[FMA0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
-//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
-//       CHECK:   %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[S4:.*]] = vector.shape_cast %[[E4]] : vector<1x2x2xf32> to vector<2x2xf32>
-//       CHECK:   %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[S5:.*]] = vector.shape_cast %[[E5]] : vector<1x2x2xf32> to vector<2x2xf32>
-//       CHECK:   %[[FMA1:.*]] = vector.fma %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>
+//       CHECK:   %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_OUT_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_OUT_1:.*]] = vector.shape_cast %[[E_OUT_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA1:.*]] = vector.fma %[[S_LHS_1]], %[[S_RHS_1]], %[[S_OUT_1]] : vector<2x2xf32>
 //       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[FMA1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
-//       CHECK:   %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[S6:.*]] = vector.shape_cast %[[E6]] : vector<1x2x2xf32> to vector<2x2xf32>
-//       CHECK:   %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[S7:.*]] = vector.shape_cast %[[E7]] : vector<1x2x2xf32> to vector<2x2xf32>
-//       CHECK:   %[[E8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[S8:.*]] = vector.shape_cast %[[E8]] : vector<1x2x2xf32> to vector<2x2xf32>
-//       CHECK:   %[[FMA2:.*]] = vector.fma %[[S6]], %[[S7]], %[[S8]] : vector<2x2xf32>
+//       CHECK:   %[[E_LHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_LHS_2:.*]] = vector.shape_cast %[[E_LHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_RHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_RHS_2:.*]] = vector.shape_cast %[[E_RHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_OUT_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_OUT_2:.*]] = vector.shape_cast %[[E_OUT_2]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA2:.*]] = vector.fma %[[S_LHS_2]], %[[S_RHS_2]], %[[S_OUT_2]] : vector<2x2xf32>
 //       CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[FMA2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
 //       CHECK:   return %[[I2]] : vector<3x2x2xf32>
 

>From 5052b469d0a22188054c0b3d1c3716ee3589f3a8 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 14 Oct 2025 00:30:23 +0000
Subject: [PATCH 5/8] Fix CHECK lines in @elementwise test to bind function arguments

---
 mlir/test/Dialect/Vector/vector-unroll-options.mlir | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index ca06c037c4b11..afec43101132e 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -470,16 +470,17 @@ func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector
   return %0 : vector<2x2x2xf32>
 }
 // CHECK-LABEL: func @elementwise
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<2x2x2xf32>, %[[ARG1:.*]]: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
 //       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32>
-//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
-//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[ADD0:.*]] = arith.addf %[[S0]], %[[S1]] : vector<2x2xf32>
 //       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[ADD0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
-//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
-//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
 //       CHECK:   %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
 //       CHECK:   %[[ADD1:.*]] = arith.addf %[[S2]], %[[S3]] : vector<2x2xf32>
 //       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>

>From 61c6edaab8bb687f094f62c842d7a8be701d0d51 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 14 Oct 2025 01:26:43 +0000
Subject: [PATCH 6/8] Rename CHECK variable

---
 .../Dialect/Vector/vector-unroll-options.mlir | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index afec43101132e..9f5080777b7e7 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -472,16 +472,16 @@ func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector
 // CHECK-LABEL: func @elementwise
 //  CHECK-SAME: (%[[ARG0:.*]]: vector<2x2x2xf32>, %[[ARG1:.*]]: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
 //       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32>
-//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
-//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
-//       CHECK:   %[[ADD0:.*]] = arith.addf %[[S0]], %[[S1]] : vector<2x2xf32>
+//       CHECK:   %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[ADD0:.*]] = arith.addf %[[S_LHS_0]], %[[S_RHS_0]] : vector<2x2xf32>
 //       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[ADD0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
-//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
-//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
-//       CHECK:   %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
-//       CHECK:   %[[ADD1:.*]] = arith.addf %[[S2]], %[[S3]] : vector<2x2xf32>
+//       CHECK:   %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[ADD1:.*]] = arith.addf %[[S_LHS_1]], %[[S_RHS_1]] : vector<2x2xf32>
 //       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
 //       CHECK:   return %[[I1]] : vector<2x2x2xf32>

>From bb67a457f99c4a8e95b56945669ac1fb971b9dfe Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 14 Oct 2025 16:00:50 +0000
Subject: [PATCH 7/8] Add 4D to 2D test

---
 .../lib/Dialect/Vector/Transforms/VectorUnroll.cpp |  4 ++--
 .../test/Dialect/Vector/vector-unroll-options.mlir | 14 ++++++++++++++
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 3b0df2bab9717..fbae0989bed26 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -486,7 +486,7 @@ struct UnrollElementwisePattern : public RewritePattern {
     Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                              rewriter.getZeroAttr(dstVecType));
     SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
-    VectorType computeVecType =
+    VectorType unrolledVecType =
         VectorType::get(*targetShape, dstVecType.getElementType());
 
     // Create the unrolled computation.
@@ -512,7 +512,7 @@ struct UnrollElementwisePattern : public RewritePattern {
       }
 
       Operation *newOp = cloneOpWithOperandsAndTypes(
-          rewriter, loc, op, extractOperands, computeVecType);
+          rewriter, loc, op, extractOperands, unrolledVecType);
 
       Value computeResult = newOp->getResult(0);
 
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 9f5080777b7e7..8995da806b5aa 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -485,3 +485,17 @@ func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector
 //       CHECK:   %[[ADD1:.*]] = arith.addf %[[S_LHS_1]], %[[S_RHS_1]] : vector<2x2xf32>
 //       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
 //       CHECK:   return %[[I1]] : vector<2x2x2xf32>
+
+
+func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf32>) -> vector<2x2x2x2xf32> {
+  %0 = arith.addf %v1, %v2 : vector<2x2x2x2xf32>
+  return %0 : vector<2x2x2x2xf32>
+}
+
+//   CHECK-LABEL: func @elementwise_4D_to_2D
+//    CHECK-SAME: (%[[ARG0:.*]]: vector<2x2x2x2xf32>, %[[ARG1:.*]]: vector<2x2x2x2xf32>) -> vector<2x2x2x2xf32> {
+// CHECK-DAG-COUNT-4:   vector.extract_strided_slice %[[ARG0]] {offsets = [{{.*}}], sizes = [1, 1, 2, 2], strides = [1, 1, 1, 1]} : vector<2x2x2x2xf32> to vector<1x1x2x2xf32>
+// CHECK-DAG-COUNT-8:   vector.shape_cast {{.*}} : vector<1x1x2x2xf32> to vector<2x2xf32>
+// CHECK-DAG-COUNT-4:   vector.extract_strided_slice %[[ARG1]] {offsets = [{{.*}}], sizes = [1, 1, 2, 2], strides = [1, 1, 1, 1]} : vector<2x2x2x2xf32> to vector<1x1x2x2xf32>
+// CHECK-DAG-COUNT-4:   arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
+// CHECK-DAG-COUNT-4:   vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [{{.*}}], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2x2xf32>
\ No newline at end of file

>From c203150dfa395f51dd12ac874b006c84c629de4b Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 14 Oct 2025 16:07:05 +0000
Subject: [PATCH 8/8] Add newline

---
 mlir/test/Dialect/Vector/vector-unroll-options.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 8995da806b5aa..b44163ea2692e 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -498,4 +498,4 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
 // CHECK-DAG-COUNT-8:   vector.shape_cast {{.*}} : vector<1x1x2x2xf32> to vector<2x2xf32>
 // CHECK-DAG-COUNT-4:   vector.extract_strided_slice %[[ARG1]] {offsets = [{{.*}}], sizes = [1, 1, 2, 2], strides = [1, 1, 1, 1]} : vector<2x2x2x2xf32> to vector<1x1x2x2xf32>
 // CHECK-DAG-COUNT-4:   arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
-// CHECK-DAG-COUNT-4:   vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [{{.*}}], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2x2xf32>
\ No newline at end of file
+// CHECK-DAG-COUNT-4:   vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [{{.*}}], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2x2xf32>



More information about the Mlir-commits mailing list