[Mlir-commits] [mlir] [mlir][spirv][vector] Support converting vector.from_elements to SPIR-V (PR #118540)

Andrea Faulds llvmlistbot at llvm.org
Wed Dec 4 08:11:40 PST 2024


https://github.com/andfau-amd updated https://github.com/llvm/llvm-project/pull/118540

>From c4bdeffabad0196be017fb40df20a52019f10ce6 Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Mon, 2 Dec 2024 17:35:01 +0100
Subject: [PATCH 1/3] 
 73457-runner-migration-vulkan-runner-IR-dump-hacks-2024-12-02

---
 mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index 2dd539ef83481f..51b945ce37e486 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -74,7 +74,7 @@ static LogicalResult runMLIRPasses(Operation *op,
   if (options.spirvWebGPUPrepare)
     modulePM.addPass(spirv::createSPIRVWebGPUPreparePass());
 
-  passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
+  /*passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
   passManager.addPass(createFinalizeMemRefToLLVMConversionPass());
   passManager.addPass(createConvertVectorToLLVMPass());
   passManager.nest<func::FuncOp>().addPass(LLVM::createRequestCWrappersPass());
@@ -82,10 +82,12 @@ static LogicalResult runMLIRPasses(Operation *op,
   funcToLLVMOptions.indexBitwidth =
       DataLayout(module).getTypeSizeInBits(IndexType::get(module.getContext()));
   passManager.addPass(createConvertFuncToLLVMPass(funcToLLVMOptions));
-  passManager.addPass(createReconcileUnrealizedCastsPass());
-  passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
+  passManager.addPass(createReconcileUnrealizedCastsPass());*/
+  //passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
 
-  return passManager.run(module);
+  auto res= passManager.run(module);//FIXME
+  module.dump();
+  return res;
 }
 
 int main(int argc, char **argv) {

>From 538e83236159aea193eb31ecbd27b2006a782b27 Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Wed, 4 Dec 2024 16:34:52 +0100
Subject: [PATCH 2/3] 2024-12-04 dump IR after each phase of vulkan runner
 instead

---
 .../tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index 51b945ce37e486..2b3524973c970e 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -58,9 +58,12 @@ static LogicalResult runMLIRPasses(Operation *op,
   auto module = dyn_cast<ModuleOp>(op);
   if (!module)
     return op->emitOpError("expected a 'builtin.module' op");
-  PassManager passManager(module.getContext());
+  auto ctx = module.getContext();
+  PassManager passManager(ctx);
   if (failed(applyPassManagerCLOptions(passManager)))
     return failure();
+  ctx->disableMultithreading();
+  passManager.enableIRPrinting();
 
   passManager.addPass(createGpuKernelOutliningPass());
   passManager.addPass(memref::createFoldMemRefAliasOpsPass());
@@ -74,7 +77,7 @@ static LogicalResult runMLIRPasses(Operation *op,
   if (options.spirvWebGPUPrepare)
     modulePM.addPass(spirv::createSPIRVWebGPUPreparePass());
 
-  /*passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
+  passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
   passManager.addPass(createFinalizeMemRefToLLVMConversionPass());
   passManager.addPass(createConvertVectorToLLVMPass());
   passManager.nest<func::FuncOp>().addPass(LLVM::createRequestCWrappersPass());
@@ -82,11 +85,11 @@ static LogicalResult runMLIRPasses(Operation *op,
   funcToLLVMOptions.indexBitwidth =
       DataLayout(module).getTypeSizeInBits(IndexType::get(module.getContext()));
   passManager.addPass(createConvertFuncToLLVMPass(funcToLLVMOptions));
-  passManager.addPass(createReconcileUnrealizedCastsPass());*/
-  //passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
+  passManager.addPass(createReconcileUnrealizedCastsPass());
+  passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
 
   auto res= passManager.run(module);//FIXME
-  module.dump();
+  //module.dump();
   return res;
 }
 

>From f1454e781d7ac7d5913c431657a50a411972e7e1 Mon Sep 17 00:00:00 2001
From: Andrea Faulds <andrea.faulds at amd.com>
Date: Wed, 4 Dec 2024 17:11:18 +0100
Subject: [PATCH 3/3] [mlir][spirv][vector] Support converting
 vector.from_elements to SPIR-V

---
 .../VectorToSPIRV/VectorToSPIRV.cpp           | 31 +++++++++++++++++--
 .../VectorToSPIRV/vector-to-spirv.mlir        | 29 +++++++++++++++++
 2 files changed, 58 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 656b1cb3e99a1d..d3731db1ce55c9 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -220,6 +220,32 @@ struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
   }
 };
 
+/// Lowers vector.from_elements to a SPIR-V scalar or 1-D composite value.
+struct VectorFromElementsOpConvert final
+    : public OpConversionPattern<vector::FromElementsOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    if (!resultType)
+      return failure();
+    // Take type-converted operands from the adaptor, not the original op.
+    ValueRange elements = adaptor.getElements();
+    if (isa<spirv::ScalarType>(resultType)) {
+      // Single-element result (0-d or 1-element vector): forward the scalar.
+      rewriter.replaceOp(op, elements[0]);
+      return success();
+    }
+    // SPIRVTypeConverter rejects vectors with rank > 1, so only 1-D remains.
+    assert(cast<VectorType>(resultType).getRank() == 1);
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, resultType,
+                                                             elements);
+    return success();
+  }
+};
+
 struct VectorInsertOpConvert final
     : public OpConversionPattern<vector::InsertOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -952,8 +978,9 @@ void mlir::populateVectorToSPIRVPatterns(
       VectorBitcastConvert, VectorBroadcastConvert,
       VectorExtractElementOpConvert, VectorExtractOpConvert,
       VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
-      VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
-      VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+      VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
+      VectorInsertElementOpConvert, VectorInsertOpConvert,
+      VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
       VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
       VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 8796f153c4911b..103148633bf97c 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -217,6 +217,35 @@ func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: @from_elements_0d
+//  CHECK-SAME: %[[X:.+]]: f32
+//       CHECK:   %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[X]]
+//       CHECK:   return %[[CAST]]
+func.func @from_elements_0d(%x : f32) -> vector<f32> {
+  %v = vector.from_elements %x : vector<f32>
+  return %v : vector<f32>
+}
+
+// CHECK-LABEL: @from_elements_1x
+//  CHECK-SAME: %[[X:.+]]: f32
+//       CHECK:   %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[X]]
+//       CHECK:   return %[[CAST]]
+func.func @from_elements_1x(%x : f32) -> vector<1xf32> {
+  %v = vector.from_elements %x : vector<1xf32>
+  return %v : vector<1xf32>
+}
+
+// CHECK-LABEL: @from_elements_3x
+//  CHECK-SAME: %[[X:.+]]: f32, %[[Y:.+]]: f32, %[[Z:.+]]: f32
+//       CHECK:   %[[VEC:.+]] = spirv.CompositeConstruct %[[X]], %[[Y]], %[[Z]] : (f32, f32, f32) -> vector<3xf32>
+//       CHECK:   return %[[VEC]]
+func.func @from_elements_3x(%x : f32, %y : f32, %z : f32) -> vector<3xf32> {
+  %v = vector.from_elements %x, %y, %z : vector<3xf32>
+  return %v : vector<3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @insert
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
 //       CHECK:   spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>



More information about the Mlir-commits mailing list