[flang-commits] [flang] [flang][cuda] Inline this_warp() calls (PR #146134)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Fri Jun 27 11:33:37 PDT 2025


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/146134

None

>From 997b083577721328754870a9f100278ed7685560 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 27 Jun 2025 11:32:27 -0700
Subject: [PATCH] [flang][cuda] Inline this_warp() calls

---
 .../flang/Optimizer/Builder/IntrinsicCall.h   |  1 +
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 40 +++++++++++++++++++
 flang/module/cooperative_groups.f90           | 13 ++++++
 flang/test/Lower/CUDA/cuda-cooperative.cuf    | 21 ++++++++++
 4 files changed, 75 insertions(+)

diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 3cf7a4920ed7d..1e8c1198fb949 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -443,6 +443,7 @@ struct IntrinsicLibrary {
   fir::ExtendedValue genTranspose(mlir::Type,
                                   llvm::ArrayRef<fir::ExtendedValue>);
   mlir::Value genThisGrid(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genThisWarp(mlir::Type, llvm::ArrayRef<mlir::Value>);
   void genThreadFence(llvm::ArrayRef<fir::ExtendedValue>);
   void genThreadFenceBlock(llvm::ArrayRef<fir::ExtendedValue>);
   void genThreadFenceSystem(llvm::ArrayRef<fir::ExtendedValue>);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index a00dc9f5d30a2..80a9389d74a7f 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -933,6 +933,7 @@ static constexpr IntrinsicHandler handlers[]{
      /*isElemental=*/false},
     {"tand", &I::genTand},
     {"this_grid", &I::genThisGrid, {}, /*isElemental=*/false},
+    {"this_warp", &I::genThisWarp, {}, /*isElemental=*/false},
     {"threadfence", &I::genThreadFence, {}, /*isElemental=*/false},
     {"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false},
     {"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false},
@@ -8194,6 +8195,45 @@ mlir::Value IntrinsicLibrary::genThisGrid(mlir::Type resultType,
   return res;
 }
 
+// THIS_WARP
+mlir::Value IntrinsicLibrary::genThisWarp(mlir::Type resultType,
+                                          llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 0);
+  // The result is a coalesced_group record laid out as {handle, size, rank}.
+  auto recTy = mlir::dyn_cast<fir::RecordType>(resultType);
+  assert(recTy && "RecordType expected");
+  auto fields = recTy.getTypeList();
+  assert(fields.size() >= 3 && "expected {handle, size, rank} components");
+  mlir::Value res = builder.create<fir::AllocaOp>(loc, resultType);
+  mlir::Type i32Ty = builder.getI32Type();
+  mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
+
+  // Store `val` into component number `idx` of the result record.
+  auto storeField = [&](unsigned idx, mlir::Value val) {
+    const auto &[fieldName, fieldTy] = fields[idx];
+    mlir::Value fieldIndex = builder.create<fir::FieldIndexOp>(
+        loc, fieldIndexType, fieldName, recTy,
+        /*typeParams=*/mlir::ValueRange{});
+    mlir::Value coord = builder.create<fir::CoordinateOp>(
+        loc, builder.getRefType(fieldTy), res, fieldIndex);
+    builder.create<fir::StoreOp>(loc, val, coord);
+  };
+
+  // coalesced_group%size = 32, the number of threads in a warp.
+  storeField(1, builder.createIntegerConstant(loc, i32Ty, 32));
+
+  // coalesced_group%rank = (tid.x & 31) + 1, using the 0-based PTX %tid.x;
+  // i.e. the 1-based position of the thread in its warp. NOTE(review): this
+  // can differ from %laneid in 2-D/3-D blocks when ntid.x % 32 != 0.
+  mlir::Value threadIdX = builder.create<mlir::NVVM::ThreadIdXOp>(loc, i32Ty);
+  mlir::Value mask = builder.createIntegerConstant(loc, i32Ty, 31);
+  mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
+  mlir::Value masked =
+      builder.create<mlir::arith::AndIOp>(loc, threadIdX, mask);
+  storeField(2, builder.create<mlir::arith::AddIOp>(loc, masked, one));
+  return res;
+}
+
 // TRAILZ
 mlir::Value IntrinsicLibrary::genTrailz(mlir::Type resultType,
                                         llvm::ArrayRef<mlir::Value> args) {
diff --git a/flang/module/cooperative_groups.f90 b/flang/module/cooperative_groups.f90
index 935e41fc56c1a..e3c4b53afd8f3 100644
--- a/flang/module/cooperative_groups.f90
+++ b/flang/module/cooperative_groups.f90
@@ -20,6 +20,12 @@ module cooperative_groups
   integer(4) :: rank
 end type grid_group
 
+type :: coalesced_group
+  type(c_devptr), private :: handle  ! not written by the inline this_warp() lowering
+  integer(4) :: size  ! number of threads in the group (32 for this_warp())
+  integer(4) :: rank  ! 1-based rank of the calling thread within the group
+end type coalesced_group
+
 interface
   attributes(device) function this_grid()
     import
@@ -27,4 +33,11 @@ attributes(device) function this_grid()
   end function
 end interface
 
+interface this_warp
+  attributes(device) function this_warp() ! calls are lowered inline by flang (this_warp intrinsic handler)
+    import
+    type(coalesced_group) :: this_warp
+  end function
+end interface
+
 end module
diff --git a/flang/test/Lower/CUDA/cuda-cooperative.cuf b/flang/test/Lower/CUDA/cuda-cooperative.cuf
index 54523b18b20db..3dc1a5e85f843 100644
--- a/flang/test/Lower/CUDA/cuda-cooperative.cuf
+++ b/flang/test/Lower/CUDA/cuda-cooperative.cuf
@@ -50,3 +50,24 @@ end subroutine
 ! CHECK: fir.store %[[SIZE]] to %[[COORD_SIZE]] : !fir.ref<i32>
 ! CHECK: %[[COORD_RANK:.*]] = fir.coordinate_of %[[RES]], rank : (!fir.ref<!fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
 ! CHECK: fir.store %[[RANK]] to %[[COORD_RANK]] : !fir.ref<i32>
+
+attributes(grid_global) subroutine w1()
+  use cooperative_groups
+  type(coalesced_group) :: gg
+  gg = this_warp()
+end subroutine
+
+! CHECK-LABEL: func.func @_QPw1()
+! CHECK: %[[WARPSIZE:.*]] = fir.alloca i32 {bindc_name = "__builtin_warpsize", uniq_name = "_QM__fortran_builtinsEC__builtin_warpsize"}
+! CHECK: %[[WARPSIZE_DECL:.*]]:2 = hlfir.declare %[[WARPSIZE]] {uniq_name = "_QM__fortran_builtinsEC__builtin_warpsize"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[COALESCED_GROUP:.*]] = fir.alloca !fir.type<_QMcooperative_groupsTcoalesced_group{_QMcooperative_groupsTcoalesced_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
+! CHECK: %[[C32:.*]] = arith.constant 32 : i32
+! CHECK: %[[SIZE_COORD:.*]] = fir.coordinate_of %[[COALESCED_GROUP]], size : (!fir.ref<!fir.type<_QMcooperative_groupsTcoalesced_group{_QMcooperative_groupsTcoalesced_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
+! CHECK: fir.store %[[C32]] to %[[SIZE_COORD]] : !fir.ref<i32>
+! CHECK: %[[THREAD_ID:.*]] = nvvm.read.ptx.sreg.tid.x : i32
+! CHECK: %[[C31:.*]] = arith.constant 31 : i32
+! CHECK: %[[C1:.*]] = arith.constant 1 : i32
+! CHECK: %[[AND:.*]] = arith.andi %[[THREAD_ID]], %[[C31]] : i32
+! CHECK: %[[RANK:.*]] = arith.addi %[[AND]], %[[C1]] : i32
+! CHECK: %[[RANK_COORD:.*]] = fir.coordinate_of %[[COALESCED_GROUP]], rank : (!fir.ref<!fir.type<_QMcooperative_groupsTcoalesced_group{_QMcooperative_groupsTcoalesced_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
+! CHECK: fir.store %[[RANK]] to %[[RANK_COORD]] : !fir.ref<i32>



More information about the flang-commits mailing list