[Mlir-commits] [mlir] [MLIR][Vector] Extend elementwise pattern to support unrolling from higher rank to lower rank (PR #162515)
Nishant Patel
llvmlistbot at llvm.org
Tue Oct 14 09:07:29 PDT 2025
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/162515
>From f854b2dcf748a0ff9e5e39d5fae70e3cd437f774 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 8 Oct 2025 16:37:37 +0000
Subject: [PATCH 1/8] Extend elementwise to support unrolling from higher rank
to lower rank
---
.../Vector/Transforms/VectorUnroll.cpp | 52 +++++++++++-----
.../Dialect/Vector/vector-unroll-options.mlir | 60 +++++++++++++++++--
2 files changed, 92 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 14639c5f1cdd3..62d65e28e8c2e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -468,23 +468,30 @@ struct UnrollElementwisePattern : public RewritePattern {
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
- // Bail-out if rank(source) != rank(target). The main limitation here is the
- // fact that `ExtractStridedSlice` requires the rank for the input and
- // output to match. If needed, we can relax this later.
- if (originalSize.size() != targetShape->size())
- return rewriter.notifyMatchFailure(
- op, "expected input vector rank to match target shape rank");
+
Location loc = op->getLoc();
+
+ // Handle rank mismatch by adding leading unit dimensions to targetShape
+ SmallVector<int64_t> adjustedTargetShape = *targetShape;
+ SmallVector<int64_t> adjustedOffsets;
+ if (originalSize.size() > targetShape->size()) {
+ // Add leading unit dimensions to targetShape
+ int64_t rankDiff = originalSize.size() - targetShape->size();
+ adjustedTargetShape.insert(adjustedTargetShape.begin(), rankDiff, 1);
+ }
+
// Prepare the result vector.
Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
rewriter.getZeroAttr(dstVecType));
- SmallVector<int64_t> strides(targetShape->size(), 1);
- VectorType newVecType =
+ SmallVector<int64_t> strides(adjustedTargetShape.size(), 1);
+ VectorType extractVecType =
+ VectorType::get(adjustedTargetShape, dstVecType.getElementType());
+ VectorType computeVecType =
VectorType::get(*targetShape, dstVecType.getElementType());
// Create the unrolled computation.
for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(originalSize, *targetShape)) {
+ StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
SmallVector<Value> extractOperands;
for (OpOperand &operand : op->getOpOperands()) {
auto vecType = dyn_cast<VectorType>(operand.get().getType());
@@ -492,14 +499,31 @@ struct UnrollElementwisePattern : public RewritePattern {
extractOperands.push_back(operand.get());
continue;
}
- extractOperands.push_back(
- rewriter.createOrFold<vector::ExtractStridedSliceOp>(
- loc, operand.get(), offsets, *targetShape, strides));
+ Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, operand.get(), offsets, adjustedTargetShape, strides);
+
+ // Reshape to remove leading unit dims if needed
+ if (adjustedTargetShape.size() > targetShape->size()) {
+ extracted = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, VectorType::get(*targetShape, vecType.getElementType()),
+ extracted);
+ }
+ extractOperands.push_back(extracted);
}
+
Operation *newOp = cloneOpWithOperandsAndTypes(
- rewriter, loc, op, extractOperands, newVecType);
+ rewriter, loc, op, extractOperands, computeVecType);
+
+ Value computeResult = newOp->getResult(0);
+
+ // Reshape back to higher rank if needed for insertion
+ if (adjustedTargetShape.size() > targetShape->size()) {
+ computeResult = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, extractVecType, computeResult);
+ }
+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
- loc, newOp->getResult(0), result, offsets, strides);
+ loc, computeResult, result, offsets, strides);
}
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 35db14e0f7f1d..a26e4b0baa05b 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,15 +188,40 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
// CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
-// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
-func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
+func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
%0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
return %0 : vector<3x2x2xf32>
}
-// CHECK-LABEL: func @negative_vector_fma_3d
-// CHECK-NOT: vector.extract_strided_slice
-// CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
-// CHECK: return
+// CHECK-LABEL: func @vector_fma_3d
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32>
+// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[FMA0:.*]] = vector.fma %[[S0]], %[[S1]], %[[S2]] : vector<2x2xf32>
+// CHECK: %[[SC0:.*]] = vector.shape_cast %[[FMA0]] : vector<2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S4:.*]] = vector.shape_cast %[[E4]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S5:.*]] = vector.shape_cast %[[E5]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[FMA1:.*]] = vector.fma %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>
+// CHECK: %[[SC1:.*]] = vector.shape_cast %[[FMA1]] : vector<2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+// CHECK: %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S6:.*]] = vector.shape_cast %[[E6]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S7:.*]] = vector.shape_cast %[[E7]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S8:.*]] = vector.shape_cast %[[E8]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[FMA2:.*]] = vector.fma %[[S6]], %[[S7]], %[[S8]] : vector<2x2xf32>
+// CHECK: %[[SC2:.*]] = vector.shape_cast %[[FMA2]] : vector<2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[SC2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+// CHECK: return %[[I2]] : vector<3x2x2xf32>
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
@@ -440,3 +465,26 @@ func.func @vector_step() -> vector<32xindex> {
// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
// CHECK: return %[[INS3]] : vector<32xindex>
+
+
+func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+ %0 = arith.addf %v1, %v2 : vector<2x2x2xf32>
+ return %0 : vector<2x2x2xf32>
+}
+// CHECK-LABEL: func @elementwise
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32>
+// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[ADD0:.*]] = arith.addf %[[S0]], %[[S1]] : vector<2x2xf32>
+// CHECK: %[[SC0:.*]] = vector.shape_cast %[[ADD0]] : vector<2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[ADD1:.*]] = arith.addf %[[S2]], %[[S3]] : vector<2x2xf32>
+// CHECK: %[[SC1:.*]] = vector.shape_cast %[[ADD1]] : vector<2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+// CHECK: return %[[I1]] : vector<2x2x2xf32>
>From 01357d9b7dc5f0707ab8220bb7eb561935f88d35 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 9 Oct 2025 17:26:21 +0000
Subject: [PATCH 2/8] Remove reshape before insert_strided_slice
---
.../Dialect/Vector/Transforms/VectorUnroll.cpp | 14 ++++++--------
.../Dialect/Vector/vector-unroll-options.mlir | 15 +++++----------
2 files changed, 11 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 62d65e28e8c2e..dd01873ded05e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -484,8 +484,6 @@ struct UnrollElementwisePattern : public RewritePattern {
Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
rewriter.getZeroAttr(dstVecType));
SmallVector<int64_t> strides(adjustedTargetShape.size(), 1);
- VectorType extractVecType =
- VectorType::get(adjustedTargetShape, dstVecType.getElementType());
VectorType computeVecType =
VectorType::get(*targetShape, dstVecType.getElementType());
@@ -516,14 +514,14 @@ struct UnrollElementwisePattern : public RewritePattern {
Value computeResult = newOp->getResult(0);
- // Reshape back to higher rank if needed for insertion
- if (adjustedTargetShape.size() > targetShape->size()) {
- computeResult = rewriter.createOrFold<vector::ShapeCastOp>(
- loc, extractVecType, computeResult);
- }
+ // Use strides sized to targetShape for proper insertion
+ SmallVector<int64_t> insertStrides =
+ (adjustedTargetShape.size() > targetShape->size())
+ ? SmallVector<int64_t>(targetShape->size(), 1)
+ : strides;
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
- loc, computeResult, result, offsets, strides);
+ loc, computeResult, result, offsets, insertStrides);
}
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index a26e4b0baa05b..feea7c0f74531 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -201,8 +201,7 @@ func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[FMA0:.*]] = vector.fma %[[S0]], %[[S1]], %[[S2]] : vector<2x2xf32>
-// CHECK: %[[SC0:.*]] = vector.shape_cast %[[FMA0]] : vector<2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[FMA0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
@@ -210,8 +209,7 @@ func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
// CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S5:.*]] = vector.shape_cast %[[E5]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[FMA1:.*]] = vector.fma %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>
-// CHECK: %[[SC1:.*]] = vector.shape_cast %[[FMA1]] : vector<2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[FMA1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
// CHECK: %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S6:.*]] = vector.shape_cast %[[E6]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
@@ -219,8 +217,7 @@ func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
// CHECK: %[[E8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S8:.*]] = vector.shape_cast %[[E8]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[FMA2:.*]] = vector.fma %[[S6]], %[[S7]], %[[S8]] : vector<2x2xf32>
-// CHECK: %[[SC2:.*]] = vector.shape_cast %[[FMA2]] : vector<2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[SC2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[FMA2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
// CHECK: return %[[I2]] : vector<3x2x2xf32>
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
@@ -478,13 +475,11 @@ func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector
// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[ADD0:.*]] = arith.addf %[[S0]], %[[S1]] : vector<2x2xf32>
-// CHECK: %[[SC0:.*]] = vector.shape_cast %[[ADD0]] : vector<2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[ADD0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[ADD1:.*]] = arith.addf %[[S2]], %[[S3]] : vector<2x2xf32>
-// CHECK: %[[SC1:.*]] = vector.shape_cast %[[ADD1]] : vector<2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
// CHECK: return %[[I1]] : vector<2x2x2xf32>
>From d59390bd29204430d5d64714a4f7792dbc1e8718 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 9 Oct 2025 17:43:01 +0000
Subject: [PATCH 3/8] Remove unused variable
---
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index dd01873ded05e..613819232c69b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -473,7 +473,6 @@ struct UnrollElementwisePattern : public RewritePattern {
// Handle rank mismatch by adding leading unit dimensions to targetShape
SmallVector<int64_t> adjustedTargetShape = *targetShape;
- SmallVector<int64_t> adjustedOffsets;
if (originalSize.size() > targetShape->size()) {
// Add leading unit dimensions to targetShape
int64_t rankDiff = originalSize.size() - targetShape->size();
>From 7f567ecadfa16e53ee7da107f19cfa9a3b244de8 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 13 Oct 2025 21:16:21 +0000
Subject: [PATCH 4/8] Address feedback
---
.../Vector/Transforms/VectorUnroll.cpp | 25 ++++++-----
.../Dialect/Vector/vector-unroll-options.mlir | 43 ++++++++++---------
2 files changed, 36 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 613819232c69b..3b0df2bab9717 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -465,24 +465,27 @@ struct UnrollElementwisePattern : public RewritePattern {
auto targetShape = getTargetShape(options, op);
if (!targetShape)
return failure();
+ int64_t targetShapeRank = targetShape->size();
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
+ int64_t originalShapeRank = originalSize.size();
Location loc = op->getLoc();
// Handle rank mismatch by adding leading unit dimensions to targetShape
- SmallVector<int64_t> adjustedTargetShape = *targetShape;
- if (originalSize.size() > targetShape->size()) {
- // Add leading unit dimensions to targetShape
- int64_t rankDiff = originalSize.size() - targetShape->size();
- adjustedTargetShape.insert(adjustedTargetShape.begin(), rankDiff, 1);
- }
-
+ SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
+ int64_t rankDiff = originalShapeRank - targetShapeRank;
+ std::fill(adjustedTargetShape.begin(),
+ adjustedTargetShape.begin() + rankDiff, 1);
+ std::copy(targetShape->begin(), targetShape->end(),
+ adjustedTargetShape.begin() + rankDiff);
+
+ int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
// Prepare the result vector.
Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
rewriter.getZeroAttr(dstVecType));
- SmallVector<int64_t> strides(adjustedTargetShape.size(), 1);
+ SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
VectorType computeVecType =
VectorType::get(*targetShape, dstVecType.getElementType());
@@ -500,7 +503,7 @@ struct UnrollElementwisePattern : public RewritePattern {
loc, operand.get(), offsets, adjustedTargetShape, strides);
// Reshape to remove leading unit dims if needed
- if (adjustedTargetShape.size() > targetShape->size()) {
+ if (adjustedTargetShapeRank > targetShapeRank) {
extracted = rewriter.createOrFold<vector::ShapeCastOp>(
loc, VectorType::get(*targetShape, vecType.getElementType()),
extracted);
@@ -515,8 +518,8 @@ struct UnrollElementwisePattern : public RewritePattern {
// Use strides sized to targetShape for proper insertion
SmallVector<int64_t> insertStrides =
- (adjustedTargetShape.size() > targetShape->size())
- ? SmallVector<int64_t>(targetShape->size(), 1)
+ (adjustedTargetShapeRank > targetShapeRank)
+ ? SmallVector<int64_t>(targetShapeRank, 1)
: strides;
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index feea7c0f74531..ca06c037c4b11 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -193,30 +193,31 @@ func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
return %0 : vector<3x2x2xf32>
}
// CHECK-LABEL: func @vector_fma_3d
+// CHECK-SAME: (%[[SRC:.*]]: vector<3x2x2xf32>) -> vector<3x2x2xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32>
-// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
-// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
-// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
-// CHECK: %[[FMA0:.*]] = vector.fma %[[S0]], %[[S1]], %[[S2]] : vector<2x2xf32>
+// CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_OUT_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_OUT_0:.*]] = vector.shape_cast %[[E_OUT_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[FMA0:.*]] = vector.fma %[[S_LHS_0]], %[[S_RHS_0]], %[[S_OUT_0]] : vector<2x2xf32>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[FMA0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
-// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
-// CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[S4:.*]] = vector.shape_cast %[[E4]] : vector<1x2x2xf32> to vector<2x2xf32>
-// CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[S5:.*]] = vector.shape_cast %[[E5]] : vector<1x2x2xf32> to vector<2x2xf32>
-// CHECK: %[[FMA1:.*]] = vector.fma %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>
+// CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_OUT_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_OUT_1:.*]] = vector.shape_cast %[[E_OUT_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[FMA1:.*]] = vector.fma %[[S_LHS_1]], %[[S_RHS_1]], %[[S_OUT_1]] : vector<2x2xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[FMA1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
-// CHECK: %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[S6:.*]] = vector.shape_cast %[[E6]] : vector<1x2x2xf32> to vector<2x2xf32>
-// CHECK: %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[S7:.*]] = vector.shape_cast %[[E7]] : vector<1x2x2xf32> to vector<2x2xf32>
-// CHECK: %[[E8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[S8:.*]] = vector.shape_cast %[[E8]] : vector<1x2x2xf32> to vector<2x2xf32>
-// CHECK: %[[FMA2:.*]] = vector.fma %[[S6]], %[[S7]], %[[S8]] : vector<2x2xf32>
+// CHECK: %[[E_LHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_LHS_2:.*]] = vector.shape_cast %[[E_LHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_RHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_RHS_2:.*]] = vector.shape_cast %[[E_RHS_2]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_OUT_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_OUT_2:.*]] = vector.shape_cast %[[E_OUT_2]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[FMA2:.*]] = vector.fma %[[S_LHS_2]], %[[S_RHS_2]], %[[S_OUT_2]] : vector<2x2xf32>
// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[FMA2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32>
// CHECK: return %[[I2]] : vector<3x2x2xf32>
>From 5052b469d0a22188054c0b3d1c3716ee3589f3a8 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 14 Oct 2025 00:30:23 +0000
Subject: [PATCH 5/8] Fix CHECK
---
mlir/test/Dialect/Vector/vector-unroll-options.mlir | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index ca06c037c4b11..afec43101132e 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -470,16 +470,17 @@ func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector
return %0 : vector<2x2x2xf32>
}
// CHECK-LABEL: func @elementwise
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2x2xf32>, %[[ARG1:.*]]: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32>
-// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[E0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
-// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[E1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[ADD0:.*]] = arith.addf %[[S0]], %[[S1]] : vector<2x2xf32>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[ADD0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
-// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[E2:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
-// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[E3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
// CHECK: %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
// CHECK: %[[ADD1:.*]] = arith.addf %[[S2]], %[[S3]] : vector<2x2xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
>From 61c6edaab8bb687f094f62c842d7a8be701d0d51 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 14 Oct 2025 01:26:43 +0000
Subject: [PATCH 6/8] Rename CHECK variables
---
.../Dialect/Vector/vector-unroll-options.mlir | 20 +++++++++----------
1 file changed, 10 insertions(+), 10 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index afec43101132e..9f5080777b7e7 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -472,16 +472,16 @@ func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector
// CHECK-LABEL: func @elementwise
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2x2xf32>, %[[ARG1:.*]]: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32>
-// CHECK: %[[E0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
-// CHECK: %[[E1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
-// CHECK: %[[ADD0:.*]] = arith.addf %[[S0]], %[[S1]] : vector<2x2xf32>
+// CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[ADD0:.*]] = arith.addf %[[S_LHS_0]], %[[S_RHS_0]] : vector<2x2xf32>
// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[ADD0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
-// CHECK: %[[E2:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
-// CHECK: %[[E3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
-// CHECK: %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
-// CHECK: %[[ADD1:.*]] = arith.addf %[[S2]], %[[S3]] : vector<2x2xf32>
+// CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[ADD1:.*]] = arith.addf %[[S_LHS_1]], %[[S_RHS_1]] : vector<2x2xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
// CHECK: return %[[I1]] : vector<2x2x2xf32>
>From bb67a457f99c4a8e95b56945669ac1fb971b9dfe Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 14 Oct 2025 16:00:50 +0000
Subject: [PATCH 7/8] Add 4D to 2D test
---
.../lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 4 ++--
.../test/Dialect/Vector/vector-unroll-options.mlir | 14 ++++++++++++++
2 files changed, 16 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 3b0df2bab9717..fbae0989bed26 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -486,7 +486,7 @@ struct UnrollElementwisePattern : public RewritePattern {
Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
rewriter.getZeroAttr(dstVecType));
SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
- VectorType computeVecType =
+ VectorType unrolledVecType =
VectorType::get(*targetShape, dstVecType.getElementType());
// Create the unrolled computation.
@@ -512,7 +512,7 @@ struct UnrollElementwisePattern : public RewritePattern {
}
Operation *newOp = cloneOpWithOperandsAndTypes(
- rewriter, loc, op, extractOperands, computeVecType);
+ rewriter, loc, op, extractOperands, unrolledVecType);
Value computeResult = newOp->getResult(0);
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 9f5080777b7e7..8995da806b5aa 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -485,3 +485,17 @@ func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector
// CHECK: %[[ADD1:.*]] = arith.addf %[[S_LHS_1]], %[[S_RHS_1]] : vector<2x2xf32>
// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32>
// CHECK: return %[[I1]] : vector<2x2x2xf32>
+
+
+func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf32>) -> vector<2x2x2x2xf32> {
+ %0 = arith.addf %v1, %v2 : vector<2x2x2x2xf32>
+ return %0 : vector<2x2x2x2xf32>
+}
+
+// CHECK-LABEL: func @elementwise_4D_to_2D
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2x2x2xf32>, %[[ARG1:.*]]: vector<2x2x2x2xf32>) -> vector<2x2x2x2xf32> {
+// CHECK-DAG-COUNT-4: vector.extract_strided_slice %[[ARG0]] {offsets = [{{.*}}], sizes = [1, 1, 2, 2], strides = [1, 1, 1, 1]} : vector<2x2x2x2xf32> to vector<1x1x2x2xf32>
+// CHECK-DAG-COUNT-8: vector.shape_cast {{.*}} : vector<1x1x2x2xf32> to vector<2x2xf32>
+// CHECK-DAG-COUNT-4: vector.extract_strided_slice %[[ARG1]] {offsets = [{{.*}}], sizes = [1, 1, 2, 2], strides = [1, 1, 1, 1]} : vector<2x2x2x2xf32> to vector<1x1x2x2xf32>
+// CHECK-DAG-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
+// CHECK-DAG-COUNT-4: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [{{.*}}], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2x2xf32>
\ No newline at end of file
>From c203150dfa395f51dd12ac874b006c84c629de4b Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 14 Oct 2025 16:07:05 +0000
Subject: [PATCH 8/8] Add missing newline at end of file
---
mlir/test/Dialect/Vector/vector-unroll-options.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 8995da806b5aa..b44163ea2692e 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -498,4 +498,4 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
// CHECK-DAG-COUNT-8: vector.shape_cast {{.*}} : vector<1x1x2x2xf32> to vector<2x2xf32>
// CHECK-DAG-COUNT-4: vector.extract_strided_slice %[[ARG1]] {offsets = [{{.*}}], sizes = [1, 1, 2, 2], strides = [1, 1, 1, 1]} : vector<2x2x2x2xf32> to vector<1x1x2x2xf32>
// CHECK-DAG-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
-// CHECK-DAG-COUNT-4: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [{{.*}}], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2x2xf32>
\ No newline at end of file
+// CHECK-DAG-COUNT-4: vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [{{.*}}], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2x2xf32>
More information about the Mlir-commits
mailing list