[flang-commits] [flang] [flang][cuda] Additional update to ExternalNameConversion (PR #119276)

via flang-commits flang-commits at lists.llvm.org
Mon Dec 9 14:02:45 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

Make the pass work on gpu.func and correctly update the launch operation.

---
Full diff: https://github.com/llvm/llvm-project/pull/119276.diff


2 Files Affected:

- (modified) flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp (+32-25) 
- (modified) flang/test/Fir/CUDA/cuda-extranal-mangling.mlir (+14-2) 


``````````diff
diff --git a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
index cfd90ff723793b..eaa40a35e38609 100644
--- a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
@@ -60,23 +60,30 @@ void ExternalNameConversionPass::runOnOperation() {
 
   llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr> remappings;
 
+  auto processFctOrGlobal = [&](mlir::Operation &funcOrGlobal) {
+    auto symName = funcOrGlobal.getAttrOfType<mlir::StringAttr>(
+              mlir::SymbolTable::getSymbolAttrName());
+    auto deconstructedName = fir::NameUniquer::deconstruct(symName);
+    if (fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) {
+      auto newName =
+          mangleExternalName(deconstructedName, appendUnderscoreOpt);
+      auto newAttr = mlir::StringAttr::get(context, newName);
+      mlir::SymbolTable::setSymbolName(&funcOrGlobal, newAttr);
+      auto newSymRef = mlir::FlatSymbolRefAttr::get(newAttr);
+      remappings.try_emplace(symName, newSymRef);
+      if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal))
+        funcOrGlobal.setAttr(fir::getInternalFuncNameAttrName(), symName);
+    }
+  };
+
   auto renameFuncOrGlobalInModule = [&](mlir::Operation *module) {
-    for (auto &funcOrGlobal : module->getRegion(0).front()) {
-      if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal) ||
-          llvm::isa<fir::GlobalOp>(funcOrGlobal)) {
-        auto symName = funcOrGlobal.getAttrOfType<mlir::StringAttr>(
-            mlir::SymbolTable::getSymbolAttrName());
-        auto deconstructedName = fir::NameUniquer::deconstruct(symName);
-        if (fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) {
-          auto newName =
-              mangleExternalName(deconstructedName, appendUnderscoreOpt);
-          auto newAttr = mlir::StringAttr::get(context, newName);
-          mlir::SymbolTable::setSymbolName(&funcOrGlobal, newAttr);
-          auto newSymRef = mlir::FlatSymbolRefAttr::get(newAttr);
-          remappings.try_emplace(symName, newSymRef);
-          if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal))
-            funcOrGlobal.setAttr(fir::getInternalFuncNameAttrName(), symName);
-        }
+    for (auto &op : module->getRegion(0).front()) {
+      if (mlir::isa<mlir::func::FuncOp, fir::GlobalOp>(op)) {
+        processFctOrGlobal(op);
+      } else if (auto gpuMod = mlir::dyn_cast<mlir::gpu::GPUModuleOp>(op)) {
+        for (auto &gpuOp : gpuMod.getBodyRegion().front())
+          if (mlir::isa<mlir::func::FuncOp, fir::GlobalOp, mlir::gpu::GPUFuncOp>(gpuOp))
+            processFctOrGlobal(gpuOp);
       }
     }
   };
@@ -85,11 +92,6 @@ void ExternalNameConversionPass::runOnOperation() {
   // globals.
   renameFuncOrGlobalInModule(op);
 
-  // Do the same in GPU modules.
-  if (auto mod = mlir::dyn_cast_or_null<mlir::ModuleOp>(*op))
-    for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>())
-      renameFuncOrGlobalInModule(gpuMod);
-
   if (remappings.empty())
     return;
 
@@ -97,11 +99,16 @@ void ExternalNameConversionPass::runOnOperation() {
   op.walk([&remappings](mlir::Operation *nestedOp) {
     llvm::SmallVector<std::pair<mlir::StringAttr, mlir::SymbolRefAttr>> updates;
     for (const mlir::NamedAttribute &attr : nestedOp->getAttrDictionary())
-      if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue()))
-        if (auto remap = remappings.find(symRef.getRootReference());
-            remap != remappings.end())
+      if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue())) {
+        if (auto remap = remappings.find(symRef.getLeafReference());
+            remap != remappings.end()) {
+          mlir::SymbolRefAttr symAttr = mlir::FlatSymbolRefAttr(remap->second);
+          if (mlir::isa<mlir::gpu::LaunchFuncOp>(nestedOp))
+            symAttr = mlir::SymbolRefAttr::get(symRef.getRootReference(), {mlir::FlatSymbolRefAttr(remap->second)});
           updates.emplace_back(std::pair<mlir::StringAttr, mlir::SymbolRefAttr>{
-              attr.getName(), mlir::SymbolRefAttr(remap->second)});
+              attr.getName(), symAttr});
+        }
+      }
     for (auto update : updates)
       nestedOp->setAttr(update.first, update.second);
   });
diff --git a/flang/test/Fir/CUDA/cuda-extranal-mangling.mlir b/flang/test/Fir/CUDA/cuda-extranal-mangling.mlir
index 551a89a7018c28..cd028a201e6fa9 100644
--- a/flang/test/Fir/CUDA/cuda-extranal-mangling.mlir
+++ b/flang/test/Fir/CUDA/cuda-extranal-mangling.mlir
@@ -1,13 +1,25 @@
 // RUN: fir-opt --split-input-file --external-name-interop %s | FileCheck %s
 
+module @mod attributes {gpu.container_module} {
+
 gpu.module @cuda_device_mod {
-  gpu.func @_QPfoo() {
+  gpu.func @_QPfoo() kernel {
     fir.call @_QPthreadfence() fastmath<contract> : () -> ()
     gpu.return
   }
   func.func private @_QPthreadfence() attributes {cuf.proc_attr = #cuf.cuda_proc<device>}
 }
 
-// CHECK-LABEL: gpu.func @_QPfoo
+func.func @test() -> () {
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %1 = llvm.mlir.constant(0 : i32) : i32
+  gpu.launch_func  @cuda_device_mod::@_QPfoo blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %1
+  return
+}
+
+// CHECK-LABEL: gpu.func @foo_()
 // CHECK: fir.call @threadfence_()
 // CHECK: func.func private @threadfence_()
+// CHECK: gpu.launch_func  @cuda_device_mod::@foo_ 
+
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/119276


More information about the flang-commits mailing list