[Mlir-commits] [mlir] [MLIR][Vector] Allow unrolling to drop leading unit dimensions (PR #164710)
llvmlistbot at llvm.org
Wed Oct 22 13:59:29 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Nishant Patel (nbpatel)
<details>
<summary>Changes</summary>
This PR relaxes the check in `getTargetShape` so that unrolling also applies when the target shape only drops dimensions. It allows cases such as `[1, 1, N]` to `[N]` to be unrolled, which behaves more like a `shape_cast` than a true unroll.
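The shape ratio that `getTargetShape` tests shows why these shapes were previously skipped. The standalone C++ below is a sketch that assumes the ratio is computed the way MLIR's `computeShapeRatio` does it: trailing dimensions are aligned and divided, and extra leading source dimensions are copied through unchanged. `shapeRatio` and `ratioIsAllOnes` are hypothetical helpers, not MLIR APIs. For `[1, 1, N]` against `[N]` every entry of the ratio is 1, which the old code treated as "no unrolling needed".

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper modeled on MLIR's computeShapeRatio (assumption): the
// trailing dimensions of `shape` are divided by `subShape`, and any extra
// leading dimensions of `shape` are copied into the result unchanged.
std::optional<std::vector<int64_t>>
shapeRatio(const std::vector<int64_t> &shape,
           const std::vector<int64_t> &subShape) {
  // A target with more dimensions than the source has no ratio.
  if (shape.size() < subShape.size())
    return std::nullopt;
  for (int64_t s : subShape)
    assert(s > 0 && "target sizes must be positive");

  std::size_t lead = shape.size() - subShape.size();
  // Leading source dimensions pass through as-is.
  std::vector<int64_t> ratio(shape.begin(),
                             shape.begin() + static_cast<std::ptrdiff_t>(lead));
  for (std::size_t i = 0; i < subShape.size(); ++i) {
    int64_t size = shape[lead + i];
    // Non-integral division: no ratio, let the caller decide.
    if (size % subShape[i] != 0)
      return std::nullopt;
    ratio.push_back(size / subShape[i]);
  }
  return ratio;
}

// Mirrors the all-1s test in getTargetShape: before this PR, an all-1s ratio
// always meant "no unrolling needed -> SKIP".
bool ratioIsAllOnes(const std::vector<int64_t> &shape,
                    const std::vector<int64_t> &subShape) {
  std::optional<std::vector<int64_t>> ratio = shapeRatio(shape, subShape);
  if (!ratio)
    return false;
  for (int64_t r : *ratio)
    if (r != 1)
      return false;
  return true;
}
```

By contrast, the existing `elementwise_4D_to_2D` test (`vector<2x2x2x2xf32>` unrolled to `2x2`) gives a ratio of `[2, 2, 1, 1]` under this model, so it never reached the all-1s branch.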
---
Full diff: https://github.com/llvm/llvm-project/pull/164710.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+21-2)
- (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+15)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fbae0989bed26..09f10b3ac952d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -122,8 +122,27 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
     return std::nullopt;
   }
   if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
-    LDBG() << "--no unrolling needed -> SKIP";
-    return std::nullopt;
+    // If maybeShapeRatio are all 1s, only allow unrolling for leading unit
+    // dimension removal: [1,1,...,n] -> [n]
+    if (maybeUnrollShape->size() <= targetShape->size()) {
+      LDBG() << "--no dimension reduction -> SKIP";
+      return std::nullopt;
+    }
+
+    size_t dimDiff = maybeUnrollShape->size() - targetShape->size();
+    ArrayRef<int64_t> srcShape = *maybeUnrollShape;
+    ArrayRef<int64_t> tgtShape = *targetShape;
+
+    // Check leading dimensions are 1s and remaining matches target
+    bool isValidRemoval = llvm::all_of(srcShape.slice(0, dimDiff),
+                                       [](int64_t dim) { return dim == 1; }) &&
+                          srcShape.slice(dimDiff) == tgtShape;
+
+    if (!isValidRemoval) {
+      LDBG() << "--not a valid leading unit dimension removal -> SKIP";
+      return std::nullopt;
+    }
+    LDBG() << "--leading unit dimension removal -> CONTINUE";
   }
   LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
   return targetShape;
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e5a98b5c67f33..9fd77645b78b5 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -496,3 +496,18 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
 // CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
 // CHECK-NOT: arith.addf
 // CHECK: return
+
+
+func.func @elementwise_leading_unit_dim(%v1: vector<1x2x2xf32>, %v2: vector<1x2x2xf32>) -> vector<1x2x2xf32> {
+  %0 = arith.addf %v1, %v2 : vector<1x2x2xf32>
+  return %0 : vector<1x2x2xf32>
+}
+
+// CHECK-LABEL: func @elementwise_leading_unit_dim
+// CHECK-SAME: (%[[ARG0:.*]]: vector<1x2x2xf32>, %[[ARG1:.*]]: vector<1x2x2xf32>) -> vector<1x2x2xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x2x2xf32>
+// CHECK: %[[S_LHS:.*]] = vector.shape_cast %[[ARG0]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[S_RHS:.*]] = vector.shape_cast %[[ARG1]] : vector<1x2x2xf32> to vector<2x2xf32>
+// CHECK: %[[ADD:.*]] = arith.addf %[[S_LHS]], %[[S_RHS]] : vector<2x2xf32>
+// CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ADD]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<1x2x2xf32>
+// CHECK: return %[[INS]] : vector<1x2x2xf32>
``````````
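The condition added to `getTargetShape` can be read in isolation. The sketch below restates it over `std::vector` instead of `ArrayRef`. `isLeadingUnitDimRemoval` is a hypothetical name; `src` stands for `*maybeUnrollShape` and `tgt` for `*targetShape`. As in the patch, it is assumed to run only after the shape ratio has been found to be all 1s.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical restatement of the new check in getTargetShape: accept only
// when `tgt` is `src` with some leading unit dimensions dropped.
bool isLeadingUnitDimRemoval(const std::vector<int64_t> &src,
                             const std::vector<int64_t> &tgt) {
  // Same or higher target rank: nothing to drop ("no dimension reduction").
  if (src.size() <= tgt.size())
    return false;
  std::size_t dimDiff = src.size() - tgt.size();
  // The dropped leading dimensions must all be unit dimensions...
  for (std::size_t i = 0; i < dimDiff; ++i)
    if (src[i] != 1)
      return false;
  // ...and the remaining trailing dimensions must match the target exactly.
  return std::equal(src.begin() + static_cast<std::ptrdiff_t>(dimDiff),
                    src.end(), tgt.begin(), tgt.end());
}
```

For the new `elementwise_leading_unit_dim` test with a `2x2` target (as the CHECK lines indicate), `src = [1, 2, 2]` and `tgt = [2, 2]`: `dimDiff` is 1, the dropped dimension is a unit dimension, and the rest matches the target. The unrolled IR therefore shape_casts both operands to `vector<2x2xf32>`, performs a single `arith.addf`, and writes the result back with `vector.insert_strided_slice` at offsets `[0, 0, 0]`.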
</details>
https://github.com/llvm/llvm-project/pull/164710