[flang-commits] [flang] 47ea854 - [flang] Update target rewrite to support workgroup and private attributions (#164515)

via flang-commits flang-commits at lists.llvm.org
Wed Oct 22 09:48:14 PDT 2025


Author: Valentin Clement (バレンタイン クレメン)
Date: 2025-10-22T09:48:10-07:00
New Revision: 47ea8543e26a823a0543bbdf2ff529ec432c09e2

URL: https://github.com/llvm/llvm-project/commit/47ea8543e26a823a0543bbdf2ff529ec432c09e2
DIFF: https://github.com/llvm/llvm-project/commit/47ea8543e26a823a0543bbdf2ff529ec432c09e2.diff

LOG: [flang] Update target rewrite to support workgroup and private attributions (#164515)

Some operations, such as gpu.func, have arguments that need to stay in
place while the signature is rewritten. This is the case for workgroup
and private attributions.
Update the target rewrite pass to take this into account when adding an
argument at the end of the function signature. If any trailing arguments
are present, the new argument is inserted just before them.

Added: 
    

Modified: 
    flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
    flang/test/Fir/CUDA/cuda-target-rewrite.mlir
    flang/tools/fir-opt/fir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index ac285b5d403df..0776346870c72 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -872,6 +872,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       }
     }
 
+    // Count the number of arguments that have to stay in place at the end of
+    // the argument list.
+    unsigned trailingArgs = 0;
+    if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>) {
+      trailingArgs =
+          func.getNumWorkgroupAttributions() + func.getNumPrivateAttributions();
+    }
+
     // Convert return value(s)
     for (auto ty : funcTy.getResults())
       llvm::TypeSwitch<mlir::Type>(ty)
@@ -981,6 +989,16 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       }
     }
 
+    // Add the argument at the end if the number of trailing arguments is 0,
+    // otherwise insert the argument at the appropriate index.
+    auto addOrInsertArgument = [&](mlir::Type ty, mlir::Location loc) {
+      unsigned inputIndex = func.front().getArguments().size() - trailingArgs;
+      auto newArg = trailingArgs == 0
+                        ? func.front().addArgument(ty, loc)
+                        : func.front().insertArgument(inputIndex, ty, loc);
+      return newArg;
+    };
+
     if (!func.empty()) {
       // If the function has a body, then apply the fixups to the arguments and
       // return ops as required. These fixups are done in place.
@@ -1117,8 +1135,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
           // original arguments. (Boxchar arguments.)
           auto newBufArg =
               func.front().insertArgument(fixup.index, fixupType, loc);
-          auto newLenArg =
-              func.front().addArgument(trailingTys[fixup.second], loc);
+          auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
           auto boxTy = oldArgTys[fixup.index - offset];
           rewriter->setInsertionPointToStart(&func.front());
           auto box = fir::EmboxCharOp::create(*rewriter, loc, boxTy, newBufArg,
@@ -1133,8 +1150,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
           // appended after all the original arguments.
           auto newProcPointerArg =
               func.front().insertArgument(fixup.index, fixupType, loc);
-          auto newLenArg =
-              func.front().addArgument(trailingTys[fixup.second], loc);
+          auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc);
           auto tupleType = oldArgTys[fixup.index - offset];
           rewriter->setInsertionPointToStart(&func.front());
           fir::FirOpBuilder builder(*rewriter, getModule());

diff  --git a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
index a334934f31723..48fee10f3db97 100644
--- a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
+++ b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
@@ -55,3 +55,56 @@ func.func @main(%arg0: complex<f64>) {
 // CHECK-SAME: (%arg0: f64, %arg1: f64) kernel {
 // CHECK: gpu.return
 // CHECK: gpu.launch_func  @testmod::@_QPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i64 dynamic_shared_memory_size %{{.*}} args(%{{.*}} : f64, %{{.*}} : f64) {cuf.proc_attr = #cuf.cuda_proc<global>}
+
+// -----
+
+module attributes {gpu.container_module, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
+  gpu.module @testmod {
+    gpu.func @_QMbarPfoo(%arg0: f32, %arg1: !fir.ref<!fir.array<100xf32>>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
+      %c0 = arith.constant 0 : index
+      memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+      gpu.return
+    }
+// CHECK-LABEL: gpu.func @_QMbarPfoo(
+// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref<!fir.array<100xf32>>, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WORKGROUP:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
+// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
+// CHECK: memref.store %{{.*}}, %[[WORKGROUP]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
+
+    gpu.func @_QMbarPfoo2(%arg0: f32, %arg1: !fir.ref<!fir.array<100xf32>>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}, %arg4 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
+      %c0 = arith.constant 0 : index
+      memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+      memref.store %arg0, %arg4[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+      gpu.return
+    }
+// CHECK-LABEL: gpu.func @_QMbarPfoo2(
+// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref<!fir.array<100xf32>>, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WG1:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}, %[[WG2:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) {
+// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
+// CHECK: memref.store %{{.*}}, %[[WG1]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
+// CHECK: memref.store %{{.*}}, %[[WG2]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
+
+    gpu.func @_QMbarPprivate(%arg0: f32, %arg1: !fir.boxchar<1>) workgroup(%arg2 : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) private(%arg3 : memref<1xf32, #gpu.address_space<private>> {llvm.align = 16 : i32}) {
+      %c0 = arith.constant 0 : index
+      memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+      memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space<private>>
+      gpu.return
+    }
+// CHECK-LABEL: gpu.func @_QMbarPprivate(
+// CHECK-SAME: %{{.*}}: f32, %[[CHAR:.*]]: !fir.ref<!fir.char<1,?>>, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space<workgroup>> {llvm.align = 16 : i32}) private(%[[PRIVATE:.*]] : memref<1xf32, #gpu.address_space<private>> {llvm.align = 16 : i32}) {
+// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref<!fir.char<1,?>>, i64) -> !fir.boxchar<1>
+// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
+// CHECK: memref.store %{{.*}}, %[[PRIVATE]][%{{.*}}] : memref<1xf32, #gpu.address_space<private>>
+    
+    gpu.func @test_with_char_proc(%arg0: f32, %arg1: tuple<() -> (), i64> {fir.char_proc}) workgroup(%arg2 : memref<1xf32, #gpu.address_space<workgroup>>) {
+      %c0 = arith.constant 0 : index
+      memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space<workgroup>>
+      gpu.return
+    }
+// CHECK-LABEL: gpu.func @test_with_char_proc(
+// CHECK-SAME: %{{.*}}: f32, %[[CHARPROC:.*]]: () -> () {fir.char_proc}, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space<workgroup>>) {
+// CHECK: %{{.*}} = fir.undefined tuple<() -> (), i64>
+// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[CHARPROC]], [0 : index] : (tuple<() -> (), i64>, () -> ()) -> tuple<() -> (), i64>
+// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[LENGTH]], [1 : index] : (tuple<() -> (), i64>, i64) -> tuple<() -> (), i64>
+// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space<workgroup>>
+  }
+}
+

diff  --git a/flang/tools/fir-opt/fir-opt.cpp b/flang/tools/fir-opt/fir-opt.cpp
index 32b0a1dfa5c7a..67d07eee1f4fc 100644
--- a/flang/tools/fir-opt/fir-opt.cpp
+++ b/flang/tools/fir-opt/fir-opt.cpp
@@ -50,6 +50,7 @@ int main(int argc, char **argv) {
 #endif
   DialectRegistry registry;
   fir::support::registerDialects(registry);
+  registry.insert<mlir::memref::MemRefDialect>();
   fir::support::addFIRExtensions(registry);
   return failed(MlirOptMain(argc, argv, "FIR modular optimizer driver\n",
       registry));


        


More information about the flang-commits mailing list