[Mlir-commits] [mlir] [mlir] Apply VectorFromElementsLowering in VectorToSPIRV. (PR #155499)
Erick Ochoa Lopez
llvmlistbot at llvm.org
Tue Aug 26 14:11:43 PDT 2025
https://github.com/amd-eochoalo created https://github.com/llvm/llvm-project/pull/155499
Closes https://github.com/llvm/llvm-project/issues/155369
>From 2f04091266745d9c97fb38dec815fd12788a5102 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 26 Aug 2025 13:15:28 -0700
Subject: [PATCH] [mlir] Apply VectorFromElementsLowering in VectorToSPIRV.
---
.../VectorToSPIRV/VectorToSPIRVPass.cpp | 8 ++++++++
.../VectorToSPIRV/vector-to-spirv.mlir | 16 ++++++++++++++++
2 files changed, 24 insertions(+)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index b3ef23085c186..068bd4b216bce 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -37,6 +38,13 @@ void ConvertVectorToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
Operation *op = getOperation();
+ {
+ RewritePatternSet patterns(context);
+ vector::populateVectorFromElementsLoweringPatterns(patterns);
+ if (failed(applyPatternsGreedily(op, std::move(patterns))))
+ return signalPassFailure();
+ }
+
auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 8918f91ef9145..a39d2509de363 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -308,6 +308,22 @@ func.func @from_elements_3x(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf
return %0: vector<3xf32>
}
+// CHECK-LABEL: @from_elements_3d
+// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32, %[[ARG3:.+]]: f32
+func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
+ // CHECK-DAG: %[[VEC3D:.+]] = ub.poison : vector<2x1x2xf32>
+ // CHECK-DAG: %[[VEC2D:.+]] = ub.poison : vector<1x2xf32>
+ // CHECK: %[[VEC1_0:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ARG1]] : (f32, f32) -> vector<2xf32>
+ // CHECK: %[[VEC1_1:.+]] = vector.insert %[[VEC1_0]], %[[VEC2D]] [0]
+ // CHECK: %[[VEC1_2:.+]] = vector.insert %[[VEC1_1]], %[[VEC3D]] [0]
+ // CHECK: %[[VEC2_0:.+]] = spirv.CompositeConstruct %[[ARG2]], %[[ARG3]] : (f32, f32) -> vector<2xf32>
+ // CHECK: %[[VEC2_1:.+]] = vector.insert %[[VEC2_0]], %[[VEC2D]] [0]
+ // CHECK: %[[VEC2_2:.+]] = vector.insert %[[VEC2_1]], %[[VEC1_2]] [1]
+ // CHECK: return %[[VEC2_2]]
+ %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
+ return %0 : vector<2x1x2xf32>
+}
+
// -----
// CHECK-LABEL: @insert
More information about the Mlir-commits mailing list