[flang-commits] [flang] [flang][cuda] Inline this_grid call for cooperative groups (PR #145796)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Wed Jun 25 15:56:11 PDT 2025


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/145796

>From d9493814270d08caec643cde50e4740910ed9ecf Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 25 Jun 2025 12:03:50 -0700
Subject: [PATCH 1/4] [flang][cuda] Inline this_grid computation

---
 .../flang/Optimizer/Builder/IntrinsicCall.h   |  1 +
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 85 +++++++++++++++++++
 flang/module/cooperative_groups.f90           | 30 +++++++
 flang/test/Lower/CUDA/cuda-cooperative.cuf    | 53 ++++++++++++
 flang/tools/f18/CMakeLists.txt                | 14 +--
 5 files changed, 178 insertions(+), 5 deletions(-)
 create mode 100644 flang/module/cooperative_groups.f90
 create mode 100644 flang/test/Lower/CUDA/cuda-cooperative.cuf

diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 17052113859e1..3cf7a4920ed7d 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -442,6 +442,7 @@ struct IntrinsicLibrary {
                                  llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genTranspose(mlir::Type,
                                   llvm::ArrayRef<fir::ExtendedValue>);
+  mlir::Value genThisGrid(mlir::Type, llvm::ArrayRef<mlir::Value>);
   void genThreadFence(llvm::ArrayRef<fir::ExtendedValue>);
   void genThreadFenceBlock(llvm::ArrayRef<fir::ExtendedValue>);
   void genThreadFenceSystem(llvm::ArrayRef<fir::ExtendedValue>);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 178b6770d6b53..a00dc9f5d30a2 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -932,6 +932,7 @@ static constexpr IntrinsicHandler handlers[]{
      {{{"count", asAddr}, {"count_rate", asAddr}, {"count_max", asAddr}}},
      /*isElemental=*/false},
     {"tand", &I::genTand},
+    {"this_grid", &I::genThisGrid, {}, /*isElemental=*/false},
     {"threadfence", &I::genThreadFence, {}, /*isElemental=*/false},
     {"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false},
     {"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false},
@@ -8109,6 +8110,90 @@ mlir::Value IntrinsicLibrary::genTand(mlir::Type resultType,
   return getRuntimeCallGenerator("tan", ftype)(builder, loc, {arg});
 }
 
+// THIS_GRID: inline computation of the cooperative_groups grid_group value.
+mlir::Value IntrinsicLibrary::genThisGrid(mlir::Type resultType,
+                                          llvm::ArrayRef<mlir::Value> args) {
+  assert(args.empty() && "this_grid takes no arguments");
+  auto recTy = mlir::dyn_cast<fir::RecordType>(resultType);
+  assert(recTy && recTy.getTypeList().size() >= 3 && "grid_group expected");
+  mlir::Value res = builder.create<fir::AllocaOp>(loc, resultType);
+  mlir::Type i32Ty = builder.getI32Type();
+
+  mlir::Value threadIdX = builder.create<mlir::NVVM::ThreadIdXOp>(loc, i32Ty);
+  mlir::Value threadIdY = builder.create<mlir::NVVM::ThreadIdYOp>(loc, i32Ty);
+  mlir::Value threadIdZ = builder.create<mlir::NVVM::ThreadIdZOp>(loc, i32Ty);
+
+  mlir::Value blockIdX = builder.create<mlir::NVVM::BlockIdXOp>(loc, i32Ty);
+  mlir::Value blockIdY = builder.create<mlir::NVVM::BlockIdYOp>(loc, i32Ty);
+  mlir::Value blockIdZ = builder.create<mlir::NVVM::BlockIdZOp>(loc, i32Ty);
+
+  mlir::Value blockDimX = builder.create<mlir::NVVM::BlockDimXOp>(loc, i32Ty);
+  mlir::Value blockDimY = builder.create<mlir::NVVM::BlockDimYOp>(loc, i32Ty);
+  mlir::Value blockDimZ = builder.create<mlir::NVVM::BlockDimZOp>(loc, i32Ty);
+  mlir::Value gridDimX = builder.create<mlir::NVVM::GridDimXOp>(loc, i32Ty);
+  mlir::Value gridDimY = builder.create<mlir::NVVM::GridDimYOp>(loc, i32Ty);
+  mlir::Value gridDimZ = builder.create<mlir::NVVM::GridDimZOp>(loc, i32Ty);
+
+  // this_grid.size = ((blockDim.z * gridDim.z) * (blockDim.y * gridDim.y)) *
+  // (blockDim.x * gridDim.x);
+  mlir::Value resZ =
+      builder.create<mlir::arith::MulIOp>(loc, blockDimZ, gridDimZ);
+  mlir::Value resY =
+      builder.create<mlir::arith::MulIOp>(loc, blockDimY, gridDimY);
+  mlir::Value resX =
+      builder.create<mlir::arith::MulIOp>(loc, blockDimX, gridDimX);
+  mlir::Value resZY = builder.create<mlir::arith::MulIOp>(loc, resZ, resY);
+  mlir::Value size = builder.create<mlir::arith::MulIOp>(loc, resZY, resX);
+
+  // tmp = ((blockIdx.z * gridDim.y * gridDim.x) + (blockIdx.y * gridDim.x)) +
+  //   blockIdx.x;
+  // this_grid.rank = tmp * ((blockDim.x * blockDim.y) * blockDim.z) +
+  //   ((threadIdx.z * blockDim.y) * blockDim.x) +
+  //   (threadIdx.y * blockDim.x) + threadIdx.x + 1;
+  mlir::Value r1 = builder.create<mlir::arith::MulIOp>(loc, blockIdZ, gridDimY);
+  mlir::Value r2 = builder.create<mlir::arith::MulIOp>(loc, r1, gridDimX);
+  mlir::Value r3 = builder.create<mlir::arith::MulIOp>(loc, blockIdY, gridDimX);
+  mlir::Value r2r3 = builder.create<mlir::arith::AddIOp>(loc, r2, r3);
+  mlir::Value tmp = builder.create<mlir::arith::AddIOp>(loc, r2r3, blockIdX);
+
+  mlir::Value bXbY =
+      builder.create<mlir::arith::MulIOp>(loc, blockDimX, blockDimY);
+  mlir::Value bXbYbZ =
+      builder.create<mlir::arith::MulIOp>(loc, bXbY, blockDimZ);
+  mlir::Value tZbY =
+      builder.create<mlir::arith::MulIOp>(loc, threadIdZ, blockDimY);
+  mlir::Value tZbYbX =
+      builder.create<mlir::arith::MulIOp>(loc, tZbY, blockDimX);
+  mlir::Value tYbX =
+      builder.create<mlir::arith::MulIOp>(loc, threadIdY, blockDimX);
+  mlir::Value rank = builder.create<mlir::arith::MulIOp>(loc, tmp, bXbYbZ);
+  rank = builder.create<mlir::arith::AddIOp>(loc, rank, tZbYbX);
+  rank = builder.create<mlir::arith::AddIOp>(loc, rank, tYbX);
+  rank = builder.create<mlir::arith::AddIOp>(loc, rank, threadIdX);
+  mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
+  rank = builder.create<mlir::arith::AddIOp>(loc, rank, one); // 1-based
+  // grid_group fields: 0 = private handle (not set here), 1 = size, 2 = rank.
+  auto sizeFieldName = recTy.getTypeList()[1].first;
+  mlir::Type sizeFieldTy = recTy.getTypeList()[1].second;
+  mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
+  mlir::Value sizeFieldIndex = builder.create<fir::FieldIndexOp>(
+      loc, fieldIndexType, sizeFieldName, recTy,
+      /*typeParams=*/mlir::ValueRange{});
+  mlir::Value sizeCoord = builder.create<fir::CoordinateOp>(
+      loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex);
+  builder.create<fir::StoreOp>(loc, size, sizeCoord);
+
+  auto rankFieldName = recTy.getTypeList()[2].first;
+  mlir::Type rankFieldTy = recTy.getTypeList()[2].second;
+  mlir::Value rankFieldIndex = builder.create<fir::FieldIndexOp>(
+      loc, fieldIndexType, rankFieldName, recTy,
+      /*typeParams=*/mlir::ValueRange{});
+  mlir::Value rankCoord = builder.create<fir::CoordinateOp>(
+      loc, builder.getRefType(rankFieldTy), res, rankFieldIndex);
+  builder.create<fir::StoreOp>(loc, rank, rankCoord);
+  return res;
+}
+
 // TRAILZ
 mlir::Value IntrinsicLibrary::genTrailz(mlir::Type resultType,
                                         llvm::ArrayRef<mlir::Value> args) {
diff --git a/flang/module/cooperative_groups.f90 b/flang/module/cooperative_groups.f90
new file mode 100644
index 0000000000000..e7d19f1c65b1a
--- /dev/null
+++ b/flang/module/cooperative_groups.f90
@@ -0,0 +1,30 @@
+!===-- module/cudedevice.f90 -----------------------------------------------===!
+!
+! Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+! See https://llvm.org/LICENSE.txt for license information.
+! SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+!
+!===------------------------------------------------------------------------===!
+
+! CUDA Fortran cooperative groups
+
+module cooperative_groups
+
+use, intrinsic :: __fortran_builtins, only: c_devptr => __builtin_c_devptr
+
+implicit none
+
+type :: grid_group  ! grid-wide group descriptor returned by this_grid()
+  type(c_devptr), private :: handle  ! not populated by flang's inline this_grid lowering
+  integer(4) :: size  ! total number of threads in the grid (blockDim * gridDim, all dims)
+  integer(4) :: rank  ! 1-based linear index of the calling thread within the grid
+end type grid_group
+
+interface
+  attributes(device) function this_grid()  ! expanded inline during lowering (no call emitted)
+    import
+    type(grid_group) :: this_grid
+  end function
+end interface
+
+end module
diff --git a/flang/test/Lower/CUDA/cuda-cooperative.cuf b/flang/test/Lower/CUDA/cuda-cooperative.cuf
new file mode 100644
index 0000000000000..d3deb8f318664
--- /dev/null
+++ b/flang/test/Lower/CUDA/cuda-cooperative.cuf
@@ -0,0 +1,53 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+! Test CUDA Fortran procedures available in cooperative_groups module.
+
+attributes(grid_global) subroutine g1()
+  use cooperative_groups
+  type(grid_group) :: gg
+  gg = this_grid()  ! expected to expand inline into the nvvm/arith ops checked below
+end subroutine
+
+! CHECK: %{{.*}} = fir.alloca !fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
+! CHECK: %[[RES:.*]] = fir.alloca !fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
+! CHECK: %[[THREAD_ID_X:.*]] = nvvm.read.ptx.sreg.tid.x : i32
+! CHECK: %[[THREAD_ID_Y:.*]] = nvvm.read.ptx.sreg.tid.y : i32
+! CHECK: %[[THREAD_ID_Z:.*]] = nvvm.read.ptx.sreg.tid.z : i32
+! CHECK: %[[BLOCK_ID_X:.*]] = nvvm.read.ptx.sreg.ctaid.x : i32
+! CHECK: %[[BLOCK_ID_Y:.*]] = nvvm.read.ptx.sreg.ctaid.y : i32
+! CHECK: %[[BLOCK_ID_Z:.*]] = nvvm.read.ptx.sreg.ctaid.z : i32
+! CHECK: %[[BLOCK_DIM_X:.*]] = nvvm.read.ptx.sreg.ntid.x : i32
+! CHECK: %[[BLOCK_DIM_Y:.*]] = nvvm.read.ptx.sreg.ntid.y : i32
+! CHECK: %[[BLOCK_DIM_Z:.*]] = nvvm.read.ptx.sreg.ntid.z : i32
+! CHECK: %[[GRID_DIM_X:.*]] = nvvm.read.ptx.sreg.nctaid.x : i32
+! CHECK: %[[GRID_DIM_Y:.*]] = nvvm.read.ptx.sreg.nctaid.y : i32
+! CHECK: %[[GRID_DIM_Z:.*]] = nvvm.read.ptx.sreg.nctaid.z : i32
+
+! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_DIM_Z]], %[[GRID_DIM_Z]] : i32
+! CHECK: %[[R2:.*]] = arith.muli %[[BLOCK_DIM_Y]], %[[GRID_DIM_Y]] : i32
+! CHECK: %[[R3:.*]] = arith.muli %[[BLOCK_DIM_X]], %[[GRID_DIM_X]] : i32
+! CHECK: %[[R4:.*]] = arith.muli %[[R1]], %[[R2]] : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[R4]], %[[R3]] : i32
+
+! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_ID_Z]], %[[GRID_DIM_Y]] : i32
+! CHECK: %[[R2:.*]] = arith.muli %[[R1]], %[[GRID_DIM_X]] : i32
+! CHECK: %[[R3:.*]] = arith.muli %[[BLOCK_ID_Y]], %[[GRID_DIM_X]] : i32
+! CHECK: %[[R4:.*]] = arith.addi %[[R2]], %[[R3]] : i32
+! CHECK: %[[TMP:.*]] = arith.addi %[[R4]], %[[BLOCK_ID_X]] : i32
+
+! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_DIM_X]], %[[BLOCK_DIM_Y]] : i32
+! CHECK: %[[R2:.*]] = arith.muli %[[R1]], %[[BLOCK_DIM_Z]] : i32
+! CHECK: %[[R3:.*]] = arith.muli %[[THREAD_ID_Z]], %[[BLOCK_DIM_Y]] : i32
+! CHECK: %[[R4:.*]] = arith.muli %[[R3]], %[[BLOCK_DIM_X]] : i32
+! CHECK: %[[R5:.*]] = arith.muli %[[THREAD_ID_Y]], %[[BLOCK_DIM_X]] : i32
+! CHECK: %[[RES0:.*]] = arith.muli %[[TMP]], %[[R2]] : i32
+! CHECK: %[[RES1:.*]] = arith.addi %[[RES0]], %[[R4]] : i32
+! CHECK: %[[RES2:.*]] = arith.addi %[[RES1]], %[[R5]] : i32
+! CHECK: %[[RES3:.*]] = arith.addi %[[RES2]], %[[THREAD_ID_X]] : i32
+! CHECK: %[[ONE:.*]] = arith.constant 1 : i32
+! CHECK: %[[RANK:.*]] = arith.addi %[[RES3]], %[[ONE]] : i32
+! CHECK: %[[COORD_SIZE:.*]] = fir.coordinate_of %[[RES]], size : (!fir.ref<!fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
+! CHECK: fir.store %[[SIZE]] to %[[COORD_SIZE]] : !fir.ref<i32>
+! CHECK: %[[COORD_RANK:.*]] = fir.coordinate_of %[[RES]], rank : (!fir.ref<!fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
+! CHECK: fir.store %[[RANK]] to %[[COORD_RANK]] : !fir.ref<i32>
+
diff --git a/flang/tools/f18/CMakeLists.txt b/flang/tools/f18/CMakeLists.txt
index fb5510d7163d1..95fe9ad87ac45 100644
--- a/flang/tools/f18/CMakeLists.txt
+++ b/flang/tools/f18/CMakeLists.txt
@@ -15,6 +15,7 @@ set(MODULES
   "mma"
   "__cuda_builtins"
   "__cuda_device"
+  "cooperative_groups"
   "cudadevice"
   "ieee_arithmetic"
   "ieee_exceptions"
@@ -60,12 +61,15 @@ if (NOT CMAKE_CROSSCOMPILING)
     elseif(${filename} STREQUAL "__ppc_intrinsics" OR
            ${filename} STREQUAL "mma")
       set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__ppc_types.mod)
-    elseif(${filename} STREQUAL "__cuda_device")
+    elseif(${filename} STREQUAL "__cuda_device" OR
+           ${filename} STREQUAL "cudadevice" OR
+           ${filename} STREQUAL "cooperative_groups")
       set(opts -fc1 -xcuda)
-      set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod)
-    elseif(${filename} STREQUAL "cudadevice")
-      set(opts -fc1 -xcuda)
-      set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_device.mod)
+      if(${filename} STREQUAL "__cuda_device")
+        set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod)
+      elseif(${filename} STREQUAL "cudadevice")
+        set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_device.mod)
+      endif()
     else()
       set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__fortran_builtins.mod)
       if(${filename} STREQUAL "iso_fortran_env")

>From 6b2a2425a9f3b89e0f327a71a5d46e698cccbff8 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 25 Jun 2025 15:20:38 -0700
Subject: [PATCH 2/4] Fix header

---
 flang/module/cooperative_groups.f90 | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/module/cooperative_groups.f90 b/flang/module/cooperative_groups.f90
index e7d19f1c65b1a..935e41fc56c1a 100644
--- a/flang/module/cooperative_groups.f90
+++ b/flang/module/cooperative_groups.f90
@@ -1,4 +1,4 @@
-!===-- module/cudedevice.f90 -----------------------------------------------===!
+!===-- module/cooperative_groups.f90 ---------------------------------------===!
 !
 ! Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 ! See https://llvm.org/LICENSE.txt for license information.

>From 0b1e0dd1101e7001c3acc7c4cdb716303f416d83 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 25 Jun 2025 15:35:20 -0700
Subject: [PATCH 3/4] Add missing dep

---
 flang/tools/f18/CMakeLists.txt | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/flang/tools/f18/CMakeLists.txt b/flang/tools/f18/CMakeLists.txt
index 95fe9ad87ac45..546b6acaaf91d 100644
--- a/flang/tools/f18/CMakeLists.txt
+++ b/flang/tools/f18/CMakeLists.txt
@@ -69,6 +69,8 @@ if (NOT CMAKE_CROSSCOMPILING)
         set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod)
       elseif(${filename} STREQUAL "cudadevice")
         set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_device.mod)
+      elseif(${filename} STREQUAL "cooperative_groups")
+        set(depends ${FLANG_INTRINSIC_MODULES_DIR}/cudadevice.mod)
       endif()
     else()
       set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__fortran_builtins.mod)

>From 3eb70a2548f558bccd12e30cbdbc1225f609e1c6 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 25 Jun 2025 15:55:59 -0700
Subject: [PATCH 4/4] Remove line

---
 flang/test/Lower/CUDA/cuda-cooperative.cuf | 1 -
 1 file changed, 1 deletion(-)

diff --git a/flang/test/Lower/CUDA/cuda-cooperative.cuf b/flang/test/Lower/CUDA/cuda-cooperative.cuf
index d3deb8f318664..54523b18b20db 100644
--- a/flang/test/Lower/CUDA/cuda-cooperative.cuf
+++ b/flang/test/Lower/CUDA/cuda-cooperative.cuf
@@ -50,4 +50,3 @@ end subroutine
 ! CHECK: fir.store %[[SIZE]] to %[[COORD_SIZE]] : !fir.ref<i32>
 ! CHECK: %[[COORD_RANK:.*]] = fir.coordinate_of %[[RES]], rank : (!fir.ref<!fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
 ! CHECK: fir.store %[[RANK]] to %[[COORD_RANK]] : !fir.ref<i32>
-



More information about the flang-commits mailing list