[Mlir-commits] [mlir] [MLIR][Vector] Extend elementwise pattern to support unrolling from higher rank to lower rank (PR #162515)
Nishant Patel
llvmlistbot at llvm.org
Wed Oct 8 10:08:00 PDT 2025
https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/162515
None
>From f854b2dcf748a0ff9e5e39d5fae70e3cd437f774 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 8 Oct 2025 16:37:37 +0000
Subject: [PATCH] Extend elementwise to support unrolling from higher rank to
lower rank
---
.../Vector/Transforms/VectorUnroll.cpp | 52 +++++++++++-----
.../Dialect/Vector/vector-unroll-options.mlir | 60 +++++++++++++++++--
2 files changed, 92 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 14639c5f1cdd3..62d65e28e8c2e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -468,23 +468,30 @@ struct UnrollElementwisePattern : public RewritePattern {
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
- // Bail-out if rank(source) != rank(target). The main limitation here is the
- // fact that `ExtractStridedSlice` requires the rank for the input and
- // output to match. If needed, we can relax this later.
- if (originalSize.size() != targetShape->size())
- return rewriter.notifyMatchFailure(
- op, "expected input vector rank to match target shape rank");
+
Location loc = op->getLoc();
+
+ // If the target shape has a lower rank than the source, pad it with leading unit dimensions so both ranks match.
+ SmallVector<int64_t> adjustedTargetShape = *targetShape;
+ SmallVector<int64_t> adjustedOffsets;
+ if (originalSize.size() > targetShape->size()) {
+ // Prepend `rankDiff` ones, e.g. a [2, 2] target on a rank-3 source becomes [1, 2, 2].
+ int64_t rankDiff = originalSize.size() - targetShape->size();
+ adjustedTargetShape.insert(adjustedTargetShape.begin(), rankDiff, 1);
+ }
+
// Prepare the result vector.
Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
rewriter.getZeroAttr(dstVecType));
- SmallVector<int64_t> strides(targetShape->size(), 1);
- VectorType newVecType =
+ SmallVector<int64_t> strides(adjustedTargetShape.size(), 1);
+ VectorType extractVecType =
+ VectorType::get(adjustedTargetShape, dstVecType.getElementType());
+ VectorType computeVecType =
VectorType::get(*targetShape, dstVecType.getElementType());
// Create the unrolled computation.
for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(originalSize, *targetShape)) {
+ StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
SmallVector<Value> extractOperands;
for (OpOperand &operand : op->getOpOperands()) {
auto vecType = dyn_cast<VectorType>(operand.get().getType());
@@ -492,14 +499,31 @@ struct UnrollElementwisePattern : public RewritePattern {
extractOperands.push_back(operand.get());
continue;
}
- extractOperands.push_back(
- rewriter.createOrFold<vector::ExtractStridedSliceOp>(
- loc, operand.get(), offsets, *targetShape, strides));
+ Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, operand.get(), offsets, adjustedTargetShape, strides);
+
+ // Drop the padded leading unit dims so the op is computed on the lower-rank target shape.
+ if (adjustedTargetShape.size() > targetShape->size()) {
+ extracted = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, VectorType::get(*targetShape, vecType.getElementType()),
+ extracted);
+ }
+ extractOperands.push_back(extracted);
}
+
Operation *newOp = cloneOpWithOperandsAndTypes(
- rewriter, loc, op, extractOperands, newVecType);
+ rewriter, loc, op, extractOperands, computeVecType);
+
+ Value computeResult = newOp->getResult(0);
+
+ // Restore the padded leading unit dims so the slice rank matches the destination vector for insertion.
+ if (adjustedTargetShape.size() > targetShape->size()) {
+ computeResult = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, extractVecType, computeResult);
+ }
+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
- loc, newOp->getResult(0), result, offsets, strides);
+ loc, computeResult, result, offsets, strides);
}
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 35db14e0f7f1d..a26e4b0baa05b 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,15 +188,40 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
// CHECK-LABEL: func @vector_fma
// CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
-// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
-func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
+func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
%0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
return %0 : vector<3x2x2xf32>
}
-// CHECK-LABEL: func @negative_vector_fma_3d
-// CHECK-NOT: vector.extract_strided_slice
-// CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
-// CHECK: return
+// CHECK-LABEL: func @vector_fma_3d
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32>
+// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[FMA0:.*]] = vector.fma %[[S0]], %[[S1]], %[[S2]] : vector<2x2xf32>
+// CHECK: %[[SC0:.*]] = vector.shape_cast %[[FMA0]] : vector<2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S4:.*]] = vector.shape_cast %[[E4]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S5:.*]] = vector.shape_cast %[[E5]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[FMA1:.*]] = vector.fma %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>
+// CHECK: %[[SC1:.*]] = vector.shape_cast %[[FMA1]] : vector<2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+// CHECK: %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S6:.*]] = vector.shape_cast %[[E6]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S7:.*]] = vector.shape_cast %[[E7]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S8:.*]] = vector.shape_cast %[[E8]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[FMA2:.*]] = vector.fma %[[S6]], %[[S7]], %[[S8]] : vector<2x2xf32>
+// CHECK: %[[SC2:.*]] = vector.shape_cast %[[FMA2]] : vector<2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[SC2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+// CHECK: return %[[I2]] : vector<3x2x2xf32>
func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
@@ -440,3 +465,26 @@ func.func @vector_step() -> vector<32xindex> {
// CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
// CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
// CHECK: return %[[INS3]] : vector<32xindex>
+
+
+func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+ %0 = arith.addf %v1, %v2 : vector<2x2x2xf32>
+ return %0 : vector<2x2x2xf32>
+}
+// CHECK-LABEL: func @elementwise
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32>
+// CHECK: %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[ADD0:.*]] = arith.addf %[[S0]], %[[S1]] : vector<2x2xf32>
+// CHECK: %[[SC0:.*]] = vector.shape_cast %[[ADD0]] : vector<2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+// CHECK: %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[ADD1:.*]] = arith.addf %[[S2]], %[[S3]] : vector<2x2xf32>
+// CHECK: %[[SC1:.*]] = vector.shape_cast %[[ADD1]] : vector<2x2xf32> to vector<1x2x2xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+// CHECK: return %[[I1]] : vector<2x2x2xf32>
More information about the Mlir-commits
mailing list