[Mlir-commits] [mlir] [mlir] Add vector.{to_elements, from_elements} unrolling to `unrollVectorsInFuncBodies` (PR #159118)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 16 09:15:15 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Erick Ochoa Lopez (amd-eochoalo)
<details>
<summary>Changes</summary>
This patch adds the unrolling patterns for vector.to_elements and vector.from_elements to the function `unrollVectorsInFuncBodies`. This affects the passes `test-convert-to-spirv` (when run with the option `run-vector-unrolling=true`) and `test-spirv-vector-unrolling`.
---
Full diff: https://github.com/llvm/llvm-project/pull/159118.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+2)
- (modified) mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir (+44)
``````````diff
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 49f4ce8de7c76..98e294b40456f 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1495,6 +1495,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
RewritePatternSet patterns(context);
auto options = vector::UnrollVectorOptions().setNativeShapeFn(
[](auto op) { return mlir::spirv::getNativeVectorShape(op); });
+ vector::populateVectorFromElementsLoweringPatterns(patterns);
+ vector::populateVectorToElementsLoweringPatterns(patterns);
populateVectorUnrollPatterns(patterns, options);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return failure();
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
index c85f4334ff2e5..0957f67690b97 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -96,3 +96,47 @@ func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
%0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
return %0 : vector<3x2xi32>
}
+
+// -----
+
+// To verify that the pattern is applied, we need to make
+// sure that the 2d vector does not come from the function
+// parameters. Otherwise, the pattern in
+// unrollVectorsInSignatures, which splits the 2d vector
+// parameter, would take precedence. Similarly, avoid
+// returning a vector, as another pattern would take precedence.
+
+// CHECK-LABEL: @unroll_to_elements_2d
+func.func @unroll_to_elements_2d() -> (f32, f32, f32, f32) {
+ %1 = "test.op"() : () -> (vector<2x2xf32>) // Opaque producer: the 2d vector is not a function parameter.
+ // CHECK: %[[VEC2D:.+]] = "test.op"
+ // CHECK: %[[VEC0:.+]] = vector.extract %[[VEC2D]][0] : vector<2xf32> from vector<2x2xf32>
+ // CHECK: %[[VEC1:.+]] = vector.extract %[[VEC2D]][1] : vector<2xf32> from vector<2x2xf32>
+ // CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]]
+ // CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]]
+ %2:4 = vector.to_elements %1 : vector<2x2xf32> // Expected to become one extract plus one 1-D to_elements per row.
+ return %2#0, %2#1, %2#2, %2#3 : f32, f32, f32, f32 // Scalar results only, so no vector is returned.
+}
+
+// -----
+
+// To verify that the pattern is applied, we need to
+// make sure that the 2d vector is used by an operation
+// and that the extracts are not folded away. In other
+// words, we can neither use "test.op" nor return the
+// value `%0 = vector.from_elements` directly.
+
+// CHECK-LABEL: @unroll_from_elements_2d
+// CHECK-SAME: (%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32, %[[ARG3:.+]]: f32)
+func.func @unroll_from_elements_2d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> (vector<2x2xf32>) {
+ // CHECK: %[[VEC0:.+]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
+ // CHECK: %[[VEC1:.+]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
+ %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32> // Expected to become one 1-D from_elements per row.
+
+ // CHECK: %[[RES0:.+]] = arith.addf %[[VEC0]], %[[VEC0]]
+ // CHECK: %[[RES1:.+]] = arith.addf %[[VEC1]], %[[VEC1]]
+ %1 = arith.addf %0, %0 : vector<2x2xf32> // Consumes %0 so the unrolled rows are not folded away.
+
+ // Informational only (not a directive); the return is presumably rewritten to: return %[[RES0]], %[[RES1]] : vector<2xf32>, vector<2xf32>
+ return %1 : vector<2x2xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/159118
More information about the Mlir-commits
mailing list