[Mlir-commits] [mlir] [mlir] Apply VectorFromElementsLowering in VectorToSPIRV. (PR #155499)

Erick Ochoa Lopez llvmlistbot at llvm.org
Tue Aug 26 14:11:43 PDT 2025


https://github.com/amd-eochoalo created https://github.com/llvm/llvm-project/pull/155499

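Run the `vector.from_elements` lowering patterns (`vector::populateVectorFromElementsLoweringPatterns`) greedily before the SPIR-V dialect conversion. Multi-dimensional `vector.from_elements` ops are then unrolled into 1-D `vector.from_elements` plus `vector.insert`, and the 1-D ops lower to `spirv.CompositeConstruct`.

Closes https://github.com/llvm/llvm-project/issues/155369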

>From 2f04091266745d9c97fb38dec815fd12788a5102 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Tue, 26 Aug 2025 13:15:28 -0700
Subject: [PATCH] [mlir] Apply VectorFromElementsLowering in VectorToSPIRV.

---
 .../VectorToSPIRV/VectorToSPIRVPass.cpp          |  8 ++++++++
 .../VectorToSPIRV/vector-to-spirv.mlir           | 16 ++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index b3ef23085c186..068bd4b216bce 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -37,6 +38,13 @@ void ConvertVectorToSPIRVPass::runOnOperation() {
   MLIRContext *context = &getContext();
   Operation *op = getOperation();
 
+  {
+    RewritePatternSet patterns(context);
+    vector::populateVectorFromElementsLoweringPatterns(patterns);
+    if (failed(applyPatternsGreedily(op, std::move(patterns))))
+      return signalPassFailure();
+  }
+
   auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 8918f91ef9145..a39d2509de363 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -308,6 +308,22 @@ func.func @from_elements_3x(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf
   return %0: vector<3xf32>
 }
 
+// CHECK-LABEL: @from_elements_3d
+//  CHECK-SAME: %[[E0:.+]]: f32, %[[E1:.+]]: f32, %[[E2:.+]]: f32, %[[E3:.+]]: f32
+func.func @from_elements_3d(%e0: f32, %e1: f32, %e2: f32, %e3: f32) -> vector<2x1x2xf32> {
+  // CHECK-DAG: %[[POISON_1X2:.+]] = ub.poison : vector<1x2xf32>
+  // CHECK-DAG: %[[POISON_2X1X2:.+]] = ub.poison : vector<2x1x2xf32>
+  // CHECK:     %[[ROW0:.+]] = spirv.CompositeConstruct %[[E0]], %[[E1]] : (f32, f32) -> vector<2xf32>
+  // CHECK:     %[[SLICE0:.+]] = vector.insert %[[ROW0]], %[[POISON_1X2]] [0]
+  // CHECK:     %[[PARTIAL:.+]] = vector.insert %[[SLICE0]], %[[POISON_2X1X2]] [0]
+  // CHECK:     %[[ROW1:.+]] = spirv.CompositeConstruct %[[E2]], %[[E3]] : (f32, f32) -> vector<2xf32>
+  // CHECK:     %[[SLICE1:.+]] = vector.insert %[[ROW1]], %[[POISON_1X2]] [0]
+  // CHECK:     %[[RESULT:.+]] = vector.insert %[[SLICE1]], %[[PARTIAL]] [1]
+  // CHECK:     return %[[RESULT]]
+  %result = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x1x2xf32>  // rank-3 input; only 1-D rows become spirv.CompositeConstruct
+  return %result : vector<2x1x2xf32>
+}
+
 // -----
 
 // CHECK-LABEL: @insert



More information about the Mlir-commits mailing list