[flang-commits] [flang] [flang][cuda] Add interfaces and lowering for tma_bulk_store (PR #165482)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Tue Oct 28 14:36:45 PDT 2025
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/165482
As defined in https://docs.nvidia.com/hpc-sdk/compilers/cuda-fortran-prog-guide/#load-and-store-functions-using-bulk-tma-operations
From 757e0bb6737025049256b3a78ef4638f4cf221e7 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 28 Oct 2025 14:35:44 -0700
Subject: [PATCH] [flang][cuda] Add interfaces and lowering for tma_bulk_store
---
.../flang/Optimizer/Builder/IntrinsicCall.h | 7 ++
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 114 ++++++++++++++++++
flang/module/cudadevice.f90 | 62 ++++++++++
flang/test/Lower/CUDA/cuda-device-proc.cuf | 92 ++++++++++++++
4 files changed, 275 insertions(+)
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index ed0cbd3bdf16b..360a7ee3033da 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -462,6 +462,13 @@ struct IntrinsicLibrary {
void genTMABulkCommitGroup(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue>);
+ void genTMABulkStoreI4(llvm::ArrayRef<fir::ExtendedValue>);
+ void genTMABulkStoreI8(llvm::ArrayRef<fir::ExtendedValue>);
+ void genTMABulkStoreR2(llvm::ArrayRef<fir::ExtendedValue>);
+ void genTMABulkStoreR4(llvm::ArrayRef<fir::ExtendedValue>);
+ void genTMABulkStoreR8(llvm::ArrayRef<fir::ExtendedValue>);
+ void genTMABulkStoreC4(llvm::ArrayRef<fir::ExtendedValue>);
+ void genTMABulkStoreC8(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genTransfer(mlir::Type,
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 0d225532f2460..aab7274727628 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1049,6 +1049,34 @@ static constexpr IntrinsicHandler handlers[]{
&I::genTMABulkS2G,
{{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}},
/*isElemental=*/false},
+ {"tma_bulk_store_c4",
+ &I::genTMABulkStoreC4,
+ {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_store_c8",
+ &I::genTMABulkStoreC8,
+ {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_store_i4",
+ &I::genTMABulkStoreI4,
+ {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_store_i8",
+ &I::genTMABulkStoreI8,
+ {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_store_r2",
+ &I::genTMABulkStoreR2,
+ {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_store_r4",
+ &I::genTMABulkStoreR4,
+ {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_store_r8",
+ &I::genTMABulkStoreR8,
+ {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+ /*isElemental=*/false},
{"tma_bulk_wait_group",
&I::genTMABulkWaitGroup,
{{}},
@@ -9289,6 +9317,92 @@ void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {
builder, loc, dst, src, fir::getBase(args[2]), {}, {});
}
+static void genTMABulkStore(fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::Value src, mlir::Value dst, mlir::Value count,
+ mlir::Value eleSize) { // Common lowering shared by every genTMABulkStore<T> wrapper.
+ mlir::Value size = mlir::arith::MulIOp::create(builder, loc, eleSize, count); // Size handed to the copy = eleSize * count.
+ src = convertPtrToNVVMSpace(builder, loc, src,
+ mlir::NVVM::NVVMMemorySpace::Shared); // src is converted to a shared-space pointer.
+ dst = convertPtrToNVVMSpace(builder, loc, dst,
+ mlir::NVVM::NVVMMemorySpace::Global); // dst is converted to a global-space pointer.
+ mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create(builder, loc, dst, src,
+ size, {}, {}); // Operand order here is (dst, src, size), the reverse of this helper's (src, dst).
+ mlir::NVVM::InlinePtxOp::create(builder, loc, mlir::TypeRange{}, {}, {},
+ "cp.async.bulk.commit_group", {}); // The commit is emitted as inline PTX right after the copy.
+ mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc,
+ builder.getI32IntegerAttr(0), {}); // Then a wait_group op with attribute 0 is emitted.
+}
+
+// TMA_BULK_STORE_C4 (CUDA): args are (src, dst, count); lowers via genTMABulkStore with an 8-byte complex(4) element size.
+void IntrinsicLibrary::genTMABulkStoreC4(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 3);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 8); // i32 constant, multiplied by count in genTMABulkStore.
+ genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_C8 (CUDA): args are (src, dst, count); lowers via genTMABulkStore with a 16-byte complex(8) element size.
+void IntrinsicLibrary::genTMABulkStoreC8(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 3);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 16); // i32 constant, multiplied by count in genTMABulkStore.
+ genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_I4 (CUDA): args are (src, dst, count); lowers via genTMABulkStore with a 4-byte integer(4) element size.
+void IntrinsicLibrary::genTMABulkStoreI4(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 3);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 4); // i32 constant, multiplied by count in genTMABulkStore.
+ genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_I8 (CUDA): args are (src, dst, count); lowers via genTMABulkStore with an 8-byte integer(8) element size.
+void IntrinsicLibrary::genTMABulkStoreI8(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 3);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 8); // i32 constant, multiplied by count in genTMABulkStore.
+ genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_R2 (CUDA): args are (src, dst, count); lowers via genTMABulkStore with a 2-byte real(2) element size.
+void IntrinsicLibrary::genTMABulkStoreR2(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 3);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 2); // i32 constant, multiplied by count in genTMABulkStore.
+ genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_R4 (CUDA): args are (src, dst, count); lowers via genTMABulkStore with a 4-byte real(4) element size.
+void IntrinsicLibrary::genTMABulkStoreR4(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 3);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 4); // i32 constant, multiplied by count in genTMABulkStore.
+ genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_R8 (CUDA): args are (src, dst, count); lowers via genTMABulkStore with an 8-byte real(8) element size.
+void IntrinsicLibrary::genTMABulkStoreR8(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 3);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 8); // i32 constant, multiplied by count in genTMABulkStore.
+ genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), eleSize);
+}
+
// TMA_BULK_WAIT_GROUP (CUDA)
void IntrinsicLibrary::genTMABulkWaitGroup(
llvm::ArrayRef<fir::ExtendedValue> args) {
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index ea54c974c9e7c..24a20c46167e8 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -2058,6 +2058,13 @@ attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes)
end subroutine
end interface
+ ! --------------------
+ ! Bulk Store functions
+ ! --------------------
+
+ ! Generic store, count is in bytes
+ ! --------------------------------
+
interface
attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
!dir$ ignore_tkr src, dst
@@ -2067,6 +2074,61 @@ attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
end subroutine
end interface
+
+ ! Store specific types, count is in elements
+ ! ------------------------------------------
+
+ interface tma_bulk_store ! generic store from shared src to device dst; one specific per element type, nelems counts elements
+ attributes(device) subroutine tma_bulk_store_c4(src, dst, nelems) ! complex(4) elements
+ !dir$ ignore_tkr (r) src, (r) dst
+ complex(4), shared :: src(*)
+ complex(4), device :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+
+ attributes(device) subroutine tma_bulk_store_c8(src, dst, nelems) ! complex(8) elements
+ !dir$ ignore_tkr (r) src, (r) dst
+ complex(8), shared :: src(*)
+ complex(8), device :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+
+ attributes(device) subroutine tma_bulk_store_i4(src, dst, nelems) ! integer(4) elements
+ !dir$ ignore_tkr (r) src, (r) dst
+ integer(4), shared :: src(*)
+ integer(4), device :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+
+ attributes(device) subroutine tma_bulk_store_i8(src, dst, nelems) ! integer(8) elements
+ !dir$ ignore_tkr (r) src, (r) dst
+ integer(8), shared :: src(*)
+ integer(8), device :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+
+ attributes(device) subroutine tma_bulk_store_r2(src, dst, nelems) ! real(2) elements
+ !dir$ ignore_tkr (r) src, (r) dst
+ real(2), shared :: src(*)
+ real(2), device :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+
+ attributes(device) subroutine tma_bulk_store_r4(src, dst, nelems) ! real(4) elements
+ !dir$ ignore_tkr (r) src, (r) dst
+ real(4), shared :: src(*)
+ real(4), device :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+
+ attributes(device) subroutine tma_bulk_store_r8(src, dst, nelems) ! real(8) elements
+ !dir$ ignore_tkr (r) src, (r) dst
+ real(8), shared :: src(*)
+ real(8), device :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+ end interface
+
contains
attributes(device) subroutine syncthreads()
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 99b1a2fc0cbf7..83b679f315e85 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -514,3 +514,95 @@ end subroutine
! CHECK-LABEL: func.func @_QPtest_barrier_try_wait_sleep()
! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %0, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr, i64, i32) -> i32
+
+attributes(global) subroutine test_tma_bulk_store_c4(c, n)
+ integer, value :: n
+ complex(4), device :: c(n)
+ complex(4), shared :: tmpa(1024)
+ integer(4) :: j, elem_count ! never assigned; this test only inspects the lowered IR
+ call tma_bulk_store(tmpa, c(j), elem_count) ! generic call with complex(4) shared source and device destination
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c4
+! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
+! CHECK: nvvm.cp.async.bulk.wait_group 0
+
+attributes(global) subroutine test_tma_bulk_store_c8(c, n)
+ integer, value :: n
+ complex(8), device :: c(n)
+ complex(8), shared :: tmpa(1024)
+ integer(4) :: j, elem_count ! never assigned; this test only inspects the lowered IR
+ call tma_bulk_store(tmpa, c(j), elem_count) ! generic call with complex(8) shared source and device destination
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c8
+! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
+! CHECK: nvvm.cp.async.bulk.wait_group 0
+
+attributes(global) subroutine test_tma_bulk_store_i4(c, n)
+ integer, value :: n
+ integer(4), device :: c(n)
+ integer(4), shared :: tmpa(1024)
+ integer(4) :: j, elem_count ! never assigned; this test only inspects the lowered IR
+ call tma_bulk_store(tmpa, c(j), elem_count) ! generic call with integer(4) shared source and device destination
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i4
+! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
+! CHECK: nvvm.cp.async.bulk.wait_group 0
+
+attributes(global) subroutine test_tma_bulk_store_i8(c, n)
+ integer, value :: n
+ integer(8), device :: c(n)
+ integer(8), shared :: tmpa(1024)
+ integer(4) :: j, elem_count ! never assigned; this test only inspects the lowered IR
+ call tma_bulk_store(tmpa, c(j), elem_count) ! generic call with integer(8) shared source and device destination
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i8
+! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
+! CHECK: nvvm.cp.async.bulk.wait_group 0
+
+
+attributes(global) subroutine test_tma_bulk_store_r2(c, n)
+ integer, value :: n
+ real(2), device :: c(n)
+ real(2), shared :: tmpa(1024)
+ integer(4) :: j, elem_count ! never assigned; this test only inspects the lowered IR
+ call tma_bulk_store(tmpa, c(j), elem_count) ! generic call with real(2) shared source and device destination
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r2
+! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
+! CHECK: nvvm.cp.async.bulk.wait_group 0
+
+attributes(global) subroutine test_tma_bulk_store_r4(c, n)
+ integer, value :: n
+ real(4), device :: c(n)
+ real(4), shared :: tmpa(1024)
+ integer(4) :: j, elem_count ! never assigned; this test only inspects the lowered IR
+ call tma_bulk_store(tmpa, c(j), elem_count) ! generic call with real(4) shared source and device destination
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r4
+! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
+! CHECK: nvvm.cp.async.bulk.wait_group 0
+
+attributes(global) subroutine test_tma_bulk_store_r8(c, n)
+ integer, value :: n
+ real(8), device :: c(n)
+ real(8), shared :: tmpa(1024)
+ integer(4) :: j, elem_count ! never assigned; this test only inspects the lowered IR
+ call tma_bulk_store(tmpa, c(j), elem_count) ! generic call with real(8) shared source and device destination
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r8
+! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
+! CHECK: nvvm.cp.async.bulk.wait_group 0
More information about the flang-commits
mailing list