[flang-commits] [flang] [flang][cuda] Inline this_thread_block() calls (PR #146144)
via flang-commits
flang-commits at lists.llvm.org
Fri Jun 27 12:31:54 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/146144.diff
4 Files Affected:
- (modified) flang/include/flang/Optimizer/Builder/IntrinsicCall.h (+1)
- (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+55)
- (modified) flang/module/cooperative_groups.f90 (+13)
- (modified) flang/test/Lower/CUDA/cuda-cooperative.cuf (+27)
``````````diff
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 3cf7a4920ed7d..0917ff46b4083 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -443,6 +443,7 @@ struct IntrinsicLibrary {
fir::ExtendedValue genTranspose(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genThisGrid(mlir::Type, llvm::ArrayRef<mlir::Value>);
+ mlir::Value genThisThreadBlock(mlir::Type, llvm::ArrayRef<mlir::Value>);
void genThreadFence(llvm::ArrayRef<fir::ExtendedValue>);
void genThreadFenceBlock(llvm::ArrayRef<fir::ExtendedValue>);
void genThreadFenceSystem(llvm::ArrayRef<fir::ExtendedValue>);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index a00dc9f5d30a2..e639447b2882c 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -933,6 +933,7 @@ static constexpr IntrinsicHandler handlers[]{
/*isElemental=*/false},
{"tand", &I::genTand},
{"this_grid", &I::genThisGrid, {}, /*isElemental=*/false},
+ {"this_thread_block", &I::genThisThreadBlock, {}, /*isElemental=*/false},
{"threadfence", &I::genThreadFence, {}, /*isElemental=*/false},
{"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false},
{"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false},
@@ -8194,6 +8195,60 @@ mlir::Value IntrinsicLibrary::genThisGrid(mlir::Type resultType,
return res;
}
+// THIS_THREAD_BLOCK
+// Allocates a record of type `resultType` and returns its address after
+// storing the product of the block dimensions into component 1 (size) and the
+// calling thread's 1-based linear index into component 2 (rank). Component 0
+// (the handle) is not written.
+mlir::Value
+IntrinsicLibrary::genThisThreadBlock(mlir::Type resultType,
+                                     llvm::ArrayRef<mlir::Value> args) {
+  assert(args.empty() && "this_thread_block takes no arguments");
+  // mlir::cast asserts on a type mismatch and never yields null, so a null
+  // check on its result can never fire; validate the type before casting.
+  assert(mlir::isa<fir::RecordType>(resultType) && "RecordType expected");
+  auto recTy = mlir::cast<fir::RecordType>(resultType);
+  mlir::Value res = builder.create<fir::AllocaOp>(loc, resultType);
+  mlir::Type i32Ty = builder.getI32Type();
+
+  // this_thread_block%size = blockDim.z * blockDim.y * blockDim.x;
+  mlir::Value blockDimX = builder.create<mlir::NVVM::BlockDimXOp>(loc, i32Ty);
+  mlir::Value blockDimY = builder.create<mlir::NVVM::BlockDimYOp>(loc, i32Ty);
+  mlir::Value blockDimZ = builder.create<mlir::NVVM::BlockDimZOp>(loc, i32Ty);
+  mlir::Value size =
+      builder.create<mlir::arith::MulIOp>(loc, blockDimZ, blockDimY);
+  size = builder.create<mlir::arith::MulIOp>(loc, size, blockDimX);
+
+  // this_thread_block%rank = ((threadIdx.z * blockDim.y) * blockDim.x) +
+  //   (threadIdx.y * blockDim.x) + threadIdx.x + 1;
+  mlir::Value threadIdX = builder.create<mlir::NVVM::ThreadIdXOp>(loc, i32Ty);
+  mlir::Value threadIdY = builder.create<mlir::NVVM::ThreadIdYOp>(loc, i32Ty);
+  mlir::Value threadIdZ = builder.create<mlir::NVVM::ThreadIdZOp>(loc, i32Ty);
+  mlir::Value r1 =
+      builder.create<mlir::arith::MulIOp>(loc, threadIdZ, blockDimY);
+  mlir::Value r2 = builder.create<mlir::arith::MulIOp>(loc, r1, blockDimX);
+  mlir::Value r3 =
+      builder.create<mlir::arith::MulIOp>(loc, threadIdY, blockDimX);
+  mlir::Value r2r3 = builder.create<mlir::arith::AddIOp>(loc, r2, r3);
+  mlir::Value rank = builder.create<mlir::arith::AddIOp>(loc, r2r3, threadIdX);
+  // The stored rank starts at 1, hence the extra increment.
+  mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
+  rank = builder.create<mlir::arith::AddIOp>(loc, rank, one);
+
+  // Components are addressed by their position in the record's type list:
+  // entry 1 receives the size, entry 2 receives the rank.
+  auto sizeFieldName = recTy.getTypeList()[1].first;
+  mlir::Type sizeFieldTy = recTy.getTypeList()[1].second;
+  mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
+  mlir::Value sizeFieldIndex = builder.create<fir::FieldIndexOp>(
+      loc, fieldIndexType, sizeFieldName, recTy,
+      /*typeParams=*/mlir::ValueRange{});
+  mlir::Value sizeCoord = builder.create<fir::CoordinateOp>(
+      loc, builder.getRefType(sizeFieldTy), res, sizeFieldIndex);
+  builder.create<fir::StoreOp>(loc, size, sizeCoord);
+
+  auto rankFieldName = recTy.getTypeList()[2].first;
+  mlir::Type rankFieldTy = recTy.getTypeList()[2].second;
+  mlir::Value rankFieldIndex = builder.create<fir::FieldIndexOp>(
+      loc, fieldIndexType, rankFieldName, recTy,
+      /*typeParams=*/mlir::ValueRange{});
+  mlir::Value rankCoord = builder.create<fir::CoordinateOp>(
+      loc, builder.getRefType(rankFieldTy), res, rankFieldIndex);
+  builder.create<fir::StoreOp>(loc, rank, rankCoord);
+  return res;
+}
+
// TRAILZ
mlir::Value IntrinsicLibrary::genTrailz(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
diff --git a/flang/module/cooperative_groups.f90 b/flang/module/cooperative_groups.f90
index 935e41fc56c1a..c7d7f47def4dc 100644
--- a/flang/module/cooperative_groups.f90
+++ b/flang/module/cooperative_groups.f90
@@ -20,6 +20,12 @@ module cooperative_groups
integer(4) :: rank
end type grid_group
+type :: thread_group  ! result type of this_thread_block()
+ type(c_devptr), private :: handle  ! not written by the inlined lowering
+ integer(4) :: size  ! blockDim.z * blockDim.y * blockDim.x
+ integer(4) :: rank  ! 1-based linear index of the calling thread in its block
+end type thread_group
+
interface
attributes(device) function this_grid()
import
@@ -27,4 +33,11 @@ attributes(device) function this_grid()
end function
end interface
+interface
+ attributes(device) function this_thread_block()  ! lowered inline by genThisThreadBlock
+ import
+ type(thread_group) :: this_thread_block  ! only size and rank are populated
+ end function
+end interface
+
end module
diff --git a/flang/test/Lower/CUDA/cuda-cooperative.cuf b/flang/test/Lower/CUDA/cuda-cooperative.cuf
index 54523b18b20db..82f66b75c3d08 100644
--- a/flang/test/Lower/CUDA/cuda-cooperative.cuf
+++ b/flang/test/Lower/CUDA/cuda-cooperative.cuf
@@ -50,3 +50,30 @@ end subroutine
! CHECK: fir.store %[[SIZE]] to %[[COORD_SIZE]] : !fir.ref<i32>
! CHECK: %[[COORD_RANK:.*]] = fir.coordinate_of %[[RES]], rank : (!fir.ref<!fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
! CHECK: fir.store %[[RANK]] to %[[COORD_RANK]] : !fir.ref<i32>
+
+attributes(grid_global) subroutine t1()  ! exercises this_thread_block() lowering
+ use cooperative_groups
+ type(thread_group) :: gg
+ gg = this_thread_block()  ! expected IR is pinned by the FileCheck lines that follow
+end subroutine
+! CHECK: %{{.*}} = fir.alloca !fir.type<_QMcooperative_groupsTthread_group{_QMcooperative_groupsTthread_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
+! CHECK: %[[THREAD_GROUP:.*]] = fir.alloca !fir.type<_QMcooperative_groupsTthread_group{_QMcooperative_groupsTthread_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
+! CHECK: %[[NTID_X:.*]] = nvvm.read.ptx.sreg.ntid.x : i32
+! CHECK: %[[NTID_Y:.*]] = nvvm.read.ptx.sreg.ntid.y : i32
+! CHECK: %[[NTID_Z:.*]] = nvvm.read.ptx.sreg.ntid.z : i32
+! CHECK: %[[SIZE_ZY:.*]] = arith.muli %[[NTID_Z]], %[[NTID_Y]] : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[SIZE_ZY]], %[[NTID_X]] : i32
+! CHECK: %[[TID_X:.*]] = nvvm.read.ptx.sreg.tid.x : i32
+! CHECK: %[[TID_Y:.*]] = nvvm.read.ptx.sreg.tid.y : i32
+! CHECK: %[[TID_Z:.*]] = nvvm.read.ptx.sreg.tid.z : i32
+! CHECK: %[[RANK_ZY:.*]] = arith.muli %[[TID_Z]], %[[NTID_Y]] : i32
+! CHECK: %[[RANK_ZYX:.*]] = arith.muli %[[RANK_ZY]], %[[NTID_X]] : i32
+! CHECK: %[[RANK_YX:.*]] = arith.muli %[[TID_Y]], %[[NTID_X]] : i32
+! CHECK: %[[RANK_SUM1:.*]] = arith.addi %[[RANK_ZYX]], %[[RANK_YX]] : i32
+! CHECK: %[[RANK_SUM2:.*]] = arith.addi %[[RANK_SUM1]], %[[TID_X]] : i32
+! CHECK: %[[C1:.*]] = arith.constant 1 : i32
+! CHECK: %[[RANK:.*]] = arith.addi %[[RANK_SUM2]], %[[C1]] : i32
+! CHECK: %[[SIZE_COORD:.*]] = fir.coordinate_of %[[THREAD_GROUP]], size : (!fir.ref<!fir.type<_QMcooperative_groupsTthread_group{_QMcooperative_groupsTthread_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
+! CHECK: fir.store %[[SIZE]] to %[[SIZE_COORD]] : !fir.ref<i32>
+! CHECK: %[[RANK_COORD:.*]] = fir.coordinate_of %[[THREAD_GROUP]], rank : (!fir.ref<!fir.type<_QMcooperative_groupsTthread_group{_QMcooperative_groupsTthread_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
+! CHECK: fir.store %[[RANK]] to %[[RANK_COORD]] : !fir.ref<i32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/146144
More information about the flang-commits
mailing list