[Mlir-commits] [llvm] [mlir] [mlir][spirv] Integrate `convert-to-spirv` into `mlir-vulkan-runner` (PR #106082)

Angel Zhang llvmlistbot at llvm.org
Mon Aug 26 07:29:39 PDT 2024


https://github.com/angelz913 created https://github.com/llvm/llvm-project/pull/106082

This PR adds a new option to the `convert-to-spirv` pass that clones and converts only the GPU kernel modules, for integration testing. The PR also replaces the `gpu-to-spirv` pass with the `convert-to-spirv` pass (with the new option enabled) in `mlir-vulkan-runner`.

>From 7e2d3046a60a849b60bd6a087336e8c76715ba1e Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Sat, 24 Aug 2024 03:10:36 +0000
Subject: [PATCH] [mlir][spirv] Integrate convert-to-spirv into
 mlir-vulkan-runner

---
 mlir/include/mlir/Conversion/Passes.td        |  5 +-
 .../Conversion/ConvertToSPIRV/CMakeLists.txt  |  1 +
 .../ConvertToSPIRV/ConvertToSPIRVPass.cpp     | 81 +++++++++++++------
 .../mlir-vulkan-runner/mlir-vulkan-runner.cpp |  5 +-
 .../llvm-project-overlay/mlir/BUILD.bazel     |  1 +
 5 files changed, 66 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 7bde9e490e4f4e..244827ade66be6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -50,7 +50,10 @@ def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
     "Run function signature conversion to convert vector types">,
     Option<"runVectorUnrolling", "run-vector-unrolling", "bool",
     /*default=*/"true",
-    "Run vector unrolling to convert vector types in function bodies">
+    "Run vector unrolling to convert vector types in function bodies">,
+    Option<"runOnGPUModules", "run-on-gpu-modules", "bool",
+    /*default=*/"false",
+    "Clone and convert only the GPU modules for integration testing">
   ];
 }
 
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
index 863ef9603da385..124a4c453e75c5 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/ConvertToSPIRV/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRConvertToSPIRVPass
   MLIRArithToSPIRV
   MLIRArithTransforms
   MLIRFuncToSPIRV
+  MLIRGPUDialect
   MLIRGPUToSPIRV
   MLIRIndexToSPIRV
   MLIRIR
diff --git a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
index 9e57b923ea6894..f624999decac48 100644
--- a/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
@@ -40,6 +41,35 @@ using namespace mlir;
 
 namespace {
 
+/// Converts memref memory spaces in `op` to SPIR-V storage classes, using the
+/// OpenCL mapping if `targetAttr` allows the Kernel capability and the Vulkan
+/// mapping otherwise. Attributes are value handles, so take it by value.
+void mapToMemRef(Operation *op, spirv::TargetEnvAttr targetAttr) {
+  spirv::TargetEnv targetEnv(targetAttr);
+  spirv::MemorySpaceToStorageClassMap memorySpaceMap =
+      targetEnv.allows(spirv::Capability::Kernel)
+          ? spirv::mapMemorySpaceToOpenCLStorageClass
+          : spirv::mapMemorySpaceToVulkanStorageClass;
+  spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+  spirv::convertMemRefTypesAndAttrs(op, converter);
+}
+
+/// Adds every pattern used by convert-to-spirv to `patternSet`, in order.
+void populateConvertToSPIRVPatterns(SPIRVTypeConverter &converter,
+                                    ScfToSPIRVContext &scfContext,
+                                    RewritePatternSet &patternSet) {
+  arith::populateCeilFloorDivExpandOpsPatterns(patternSet);
+  arith::populateArithToSPIRVPatterns(converter, patternSet);
+  populateBuiltinFuncToSPIRVPatterns(converter, patternSet);
+  populateFuncToSPIRVPatterns(converter, patternSet);
+  populateGPUToSPIRVPatterns(converter, patternSet);
+  index::populateIndexToSPIRVPatterns(converter, patternSet);
+  populateMemRefToSPIRVPatterns(converter, patternSet);
+  populateVectorToSPIRVPatterns(converter, patternSet);
+  populateSCFToSPIRVPatterns(converter, scfContext, patternSet);
+  ub::populateUBToSPIRVConversionPatterns(converter, patternSet);
+}
+
 /// A pass to perform the SPIR-V conversion.
 struct ConvertToSPIRVPass final
     : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
@@ -64,31 +94,38 @@ struct ConvertToSPIRVPass final
     RewritePatternSet patterns(context);
     ScfToSPIRVContext scfToSPIRVContext;
 
-    // Map MemRef memory space to SPIR-V storage class.
-    spirv::TargetEnv targetEnv(targetAttr);
-    bool targetEnvSupportsKernelCapability =
-        targetEnv.allows(spirv::Capability::Kernel);
-    spirv::MemorySpaceToStorageClassMap memorySpaceMap =
-        targetEnvSupportsKernelCapability
-            ? spirv::mapMemorySpaceToOpenCLStorageClass
-            : spirv::mapMemorySpaceToVulkanStorageClass;
-    spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
-    spirv::convertMemRefTypesAndAttrs(op, converter);
-
-    // Populate patterns for each dialect.
-    arith::populateCeilFloorDivExpandOpsPatterns(patterns);
-    arith::populateArithToSPIRVPatterns(typeConverter, patterns);
-    populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
-    populateFuncToSPIRVPatterns(typeConverter, patterns);
-    populateGPUToSPIRVPatterns(typeConverter, patterns);
-    index::populateIndexToSPIRVPatterns(typeConverter, patterns);
-    populateMemRefToSPIRVPatterns(typeConverter, patterns);
-    populateVectorToSPIRVPatterns(typeConverter, patterns);
-    populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
-    ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
-
-    if (failed(applyPartialConversion(op, *target, std::move(patterns))))
-      return signalPassFailure();
+    if (runOnGPUModules) {
+      SmallVector<Operation *, 1> gpuModules;
+      OpBuilder builder(context);
+      op->walk([&](gpu::GPUModuleOp gpuModule) {
+        builder.setInsertionPoint(gpuModule);
+        gpuModules.push_back(builder.clone(*gpuModule));
+      });
+      // Run conversion for each module independently as they can have
+      // different TargetEnv attributes, so the type converter, conversion
+      // target and patterns are all derived from the module's own TargetEnv.
+      for (Operation *gpuModule : gpuModules) {
+        spirv::TargetEnvAttr moduleTargetAttr =
+            spirv::lookupTargetEnvOrDefault(gpuModule);
+        mapToMemRef(gpuModule, moduleTargetAttr);
+        SPIRVTypeConverter moduleTypeConverter(moduleTargetAttr);
+        std::unique_ptr<ConversionTarget> moduleTarget =
+            SPIRVConversionTarget::get(moduleTargetAttr);
+        // A fresh set per module: applyFullConversion consumes its patterns.
+        RewritePatternSet modulePatterns(context);
+        populateConvertToSPIRVPatterns(moduleTypeConverter, scfToSPIRVContext,
+                                       modulePatterns);
+        if (failed(applyFullConversion(gpuModule, *moduleTarget,
+                                       std::move(modulePatterns))))
+          return signalPassFailure();
+      }
+    } else {
+      mapToMemRef(op, targetAttr);
+      populateConvertToSPIRVPatterns(typeConverter, scfToSPIRVContext,
+                                     patterns);
+      if (failed(applyPartialConversion(op, *target, std::move(patterns))))
+        return signalPassFailure();
+    }
   }
 };
 
diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index 032f5760361f4b..93ca922c57084f 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -12,6 +12,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
@@ -64,7 +65,9 @@ static LogicalResult runMLIRPasses(Operation *op,
   passManager.addPass(createGpuKernelOutliningPass());
   passManager.addPass(memref::createFoldMemRefAliasOpsPass());
 
-  passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));
+  ConvertToSPIRVPassOptions convertToSPIRVOptions{};
+  convertToSPIRVOptions.runOnGPUModules = true;
+  passManager.addPass(createConvertToSPIRVPass(convertToSPIRVOptions));
   OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
   modulePM.addPass(spirv::createSPIRVLowerABIAttributesPass());
   modulePM.addPass(spirv::createSPIRVUpdateVCEPass());
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index de069daf603f1e..35b3ef47603d99 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8401,6 +8401,7 @@ cc_library(
         ":ArithTransforms",
         ":ConversionPassIncGen",
         ":FuncToSPIRV",
+        ":GPUDialect",
         ":GPUToSPIRV",
         ":IR",
         ":IndexToSPIRV",



More information about the Mlir-commits mailing list