[flang-commits] [flang] [flang][cuda] Lower CUDA shared variable with cuf.shared_memory op (PR #131399)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Fri Mar 14 15:44:37 PDT 2025


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/131399

>From 12a2b801ea18b3c95da119aef31c9562e431c441 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 14 Mar 2025 14:12:42 -0700
Subject: [PATCH 1/3] Add missing builder (adds the CUFComputeSharedMemoryOffsetsAndSize pass definition)

---
 flang/include/flang/Optimizer/Transforms/Passes.td | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index e5c17cf7d8881..fbab435887b8a 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -453,6 +453,19 @@ def CUFGPUToLLVMConversion : Pass<"cuf-gpu-convert-to-llvm", "mlir::ModuleOp"> {
   ];
 }
 
+def CUFComputeSharedMemoryOffsetsAndSize
+    : Pass<"cuf-compute-shared-memory", "mlir::ModuleOp"> {
+  let summary = "Create the shared memory global variable and set offsets";
+
+  let description = [{
+    Compute the size and alignment of the shared memory global and materialize
+    it. Compute the offset of each cuf.shared_memory operation according to
+    the global and set it.
+  }];
+
+  let dependentDialects = ["fir::FIROpsDialect"];
+}
+
 def SetRuntimeCallAttributes
     : Pass<"set-runtime-call-attrs", "mlir::func::FuncOp"> {
   let summary = "Set Fortran runtime fir.call attributes targeting LLVM IR";

>From 807195f5fe63b207612c5043992b130b4379af15 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 14 Mar 2025 14:13:27 -0700
Subject: [PATCH 2/3] Remove the CUFComputeSharedMemoryOffsetsAndSize pass definition added in patch 1/3

---
 flang/include/flang/Optimizer/Transforms/Passes.td | 13 -------------
 1 file changed, 13 deletions(-)

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index fbab435887b8a..e5c17cf7d8881 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -453,19 +453,6 @@ def CUFGPUToLLVMConversion : Pass<"cuf-gpu-convert-to-llvm", "mlir::ModuleOp"> {
   ];
 }
 
-def CUFComputeSharedMemoryOffsetsAndSize
-    : Pass<"cuf-compute-shared-memory", "mlir::ModuleOp"> {
-  let summary = "Create the shared memory global variable and set offsets";
-
-  let description = [{
-    Compute the size and alignment of the shared memory global and materialize
-    it. Compute the offset of each cuf.shared_memory operation according to
-    the global and set it.
-  }];
-
-  let dependentDialects = ["fir::FIROpsDialect"];
-}
-
 def SetRuntimeCallAttributes
     : Pass<"set-runtime-call-attrs", "mlir::func::FuncOp"> {
   let summary = "Set Fortran runtime fir.call attributes targeting LLVM IR";

>From abadac9d1e34eeb4e4a894b8a50f75e6cbacf5a6 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 14 Mar 2025 14:47:31 -0700
Subject: [PATCH 3/3] [flang][cuda] Lower CUDA shared variable with
 cuf.shared_memory operation

---
 flang/lib/Lower/ConvertVariable.cpp   | 14 ++++++++++----
 flang/test/Lower/CUDA/cuda-shared.cuf | 11 +++++++++++
 2 files changed, 21 insertions(+), 4 deletions(-)
 create mode 100644 flang/test/Lower/CUDA/cuda-shared.cuf

diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index ab5e6346f8d54..05256fec67241 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -738,9 +738,11 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
     auto idxTy = builder.getIndexType();
     for (mlir::Value sh : elidedShape)
       indices.push_back(builder.createConvert(loc, idxTy, sh));
-    mlir::Value alloc = builder.create<cuf::AllocOp>(
-        loc, ty, nm, symNm, dataAttr, lenParams, indices);
-    return alloc;
+    if (dataAttr.getValue() == cuf::DataAttribute::Shared)
+      return builder.create<cuf::SharedMemoryOp>(loc, ty, nm, symNm, lenParams,
+                                                 indices);
+    return builder.create<cuf::AllocOp>(loc, ty, nm, symNm, dataAttr, lenParams,
+                                        indices);
   }
 
   // Let the builder do all the heavy lifting.
@@ -1032,12 +1034,16 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
                                                symMap);
   if (Fortran::semantics::NeedCUDAAlloc(var.getSymbol())) {
     auto *builder = &converter.getFirOpBuilder();
+    cuf::DataAttributeAttr dataAttr =
+        Fortran::lower::translateSymbolCUFDataAttribute(builder->getContext(),
+                                                        var.getSymbol());
     mlir::Location loc = converter.getCurrentLocation();
     fir::ExtendedValue exv =
         converter.getSymbolExtendedValue(var.getSymbol(), &symMap);
     auto *sym = &var.getSymbol();
     const Fortran::semantics::Scope &owner = sym->owner();
-    if (owner.kind() != Fortran::semantics::Scope::Kind::MainProgram) {
+    if (owner.kind() != Fortran::semantics::Scope::Kind::MainProgram &&
+        dataAttr.getValue() != cuf::DataAttribute::Shared) {
       converter.getFctCtx().attachCleanup([builder, loc, exv, sym]() {
         cuf::DataAttributeAttr dataAttr =
             Fortran::lower::translateSymbolCUFDataAttribute(
diff --git a/flang/test/Lower/CUDA/cuda-shared.cuf b/flang/test/Lower/CUDA/cuda-shared.cuf
new file mode 100644
index 0000000000000..0bacc4ec0b71e
--- /dev/null
+++ b/flang/test/Lower/CUDA/cuda-shared.cuf
@@ -0,0 +1,11 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+attributes(global) subroutine sharedmem()
+  real, shared :: s(32)
+  integer :: t
+  t = threadIdx%x
+  s(t) = t
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsharedmem() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
+! CHECK: %{{.*}} = cuf.shared_memory !fir.array<32xf32> {bindc_name = "s", uniq_name = "_QFsharedmemEs"} -> !fir.ref<!fir.array<32xf32>>



More information about the flang-commits mailing list