[flang-commits] [flang] [flang][cuda] Add interfaces and lowering for tma_bulk_load (PR #165474)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Tue Oct 28 14:07:36 PDT 2025
Valentin Clement (バレンタイン クレメン)
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/165474 at github.com>
https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/165474
From 17502dc89f49e047b8c0dee86da6518316f51040 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 28 Oct 2025 13:42:28 -0700
Subject: [PATCH 1/2] [flang][cuda] Add interfaces and lowering for tma_bulk_load
As defined in https://docs.nvidia.com/hpc-sdk/compilers/cuda-fortran-prog-guide/#load-and-store-functions-using-bulk-tma-operations
---
.../flang/Optimizer/Builder/IntrinsicCall.h | 7 +
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 136 ++++++++++++++++++
flang/module/cudadevice.f90 | 61 ++++++++
flang/test/Lower/CUDA/cuda-device-proc.cuf | 133 +++++++++++++++++
4 files changed, 337 insertions(+)
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index ed0cbd3bdf16b..f5ff6626da654 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -461,6 +461,13 @@ struct IntrinsicLibrary {
mlir::Value genTime(mlir::Type, llvm::ArrayRef<mlir::Value>);
void genTMABulkCommitGroup(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue>);
+ void genTMABulkLoadC4(llvm::ArrayRef<fir::ExtendedValue>);
+ void genTMABulkLoadC8(llvm::ArrayRef<fir::ExtendedValue>);
+ void genTMABulkLoadI4(llvm::ArrayRef<fir::ExtendedValue>);
+ void genTMABulkLoadI8(llvm::ArrayRef<fir::ExtendedValue>);
+ void genTMABulkLoadR2(llvm::ArrayRef<fir::ExtendedValue>);
+ void genTMABulkLoadR4(llvm::ArrayRef<fir::ExtendedValue>);
+ void genTMABulkLoadR8(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue>);
void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 0d225532f2460..c6efc8cfcb579 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1045,6 +1045,55 @@ static constexpr IntrinsicHandler handlers[]{
{"dst", asAddr},
{"nbytes", asValue}}},
/*isElemental=*/false},
+ {"tma_bulk_ldc4",
+ &I::genTMABulkLoadC4,
+ {{{"barrier", asAddr},
+ {"src", asAddr},
+ {"dst", asAddr},
+ {"nelems", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_ldc8",
+ &I::genTMABulkLoadC8,
+ {{{"barrier", asAddr},
+ {"src", asAddr},
+ {"dst", asAddr},
+ {"nelems", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_ldi4",
+ &I::genTMABulkLoadI4,
+ {{{"barrier", asAddr},
+ {"src", asAddr},
+ {"dst", asAddr},
+ {"nelems", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_ldi8",
+ &I::genTMABulkLoadI8,
+ {{{"barrier", asAddr},
+ {"src", asAddr},
+ {"dst", asAddr},
+ {"nelems", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_ldr2",
+ &I::genTMABulkLoadR2,
+ {{{"barrier", asAddr},
+ {"src", asAddr},
+ {"dst", asAddr},
+ {"nelems", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_ldr4",
+ &I::genTMABulkLoadR4,
+ {{{"barrier", asAddr},
+ {"src", asAddr},
+ {"dst", asAddr},
+ {"nelems", asValue}}},
+ /*isElemental=*/false},
+ {"tma_bulk_ldr8",
+ &I::genTMABulkLoadR8,
+ {{{"barrier", asAddr},
+ {"src", asAddr},
+ {"dst", asAddr},
+ {"nelems", asValue}}},
+ /*isElemental=*/false},
{"tma_bulk_s2g",
&I::genTMABulkS2G,
{{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}},
@@ -9278,6 +9327,93 @@ void IntrinsicLibrary::genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue> args) {
builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {});
}
+static void genTMABulkLoad(fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::Value barrier, mlir::Value src,
+ mlir::Value dst, mlir::Value nelem,
+ mlir::Value eleSize) {
+ mlir::Value size = mlir::arith::MulIOp::create(builder, loc, nelem, eleSize);
+ auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
+ barrier = builder.createConvert(loc, llvmPtrTy, barrier);
+ mlir::NVVM::InlinePtxOp::create(
+ builder, loc, mlir::TypeRange{}, {dst, src, size, barrier}, {},
+ "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], "
+ "[%1], %2, [%3];",
+ {});
+ mlir::NVVM::InlinePtxOp::create(
+ builder, loc, mlir::TypeRange{}, {barrier, size}, {},
+ "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;", {});
+}
+
+// TMA_BULK_LOADC4
+void IntrinsicLibrary::genTMABulkLoadC4(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 4);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 8);
+ genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADC8
+void IntrinsicLibrary::genTMABulkLoadC8(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 4);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 16);
+ genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADI4
+void IntrinsicLibrary::genTMABulkLoadI4(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 4);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 4);
+ genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADI8
+void IntrinsicLibrary::genTMABulkLoadI8(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 4);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 8);
+ genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADR2
+void IntrinsicLibrary::genTMABulkLoadR2(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 4);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 2);
+ genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADR4
+void IntrinsicLibrary::genTMABulkLoadR4(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 4);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 4);
+ genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADR8
+void IntrinsicLibrary::genTMABulkLoadR8(
+ llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 4);
+ mlir::Value eleSize =
+ builder.createIntegerConstant(loc, builder.getI32Type(), 8);
+ genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+ fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
// TMA_BULK_S2G (CUDA)
void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 3);
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index ea54c974c9e7c..ae8ebf7c9562d 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -2067,6 +2067,67 @@ attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
end subroutine
end interface
+ ! Load specific types, count is in elements
+ ! -----------------------------------------
+ interface tma_bulk_load
+ attributes(device) subroutine tma_bulk_ldi4(barrier, src, dst, nelems)
+ !dir$ ignore_tkr (r) src, (r) dst
+ integer(8), shared :: barrier
+ integer(4), device :: src(*)
+ integer(4), shared :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+
+ attributes(device) subroutine tma_bulk_ldi8(barrier, src, dst, nelems)
+ !dir$ ignore_tkr (r) src, (r) dst
+ integer(8), shared :: barrier
+ integer(8), device :: src(*)
+ integer(8), shared :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+
+ attributes(device) subroutine tma_bulk_ldr2(barrier, src, dst, nelems)
+ !dir$ ignore_tkr (r) src, (r) dst
+ integer(8), shared :: barrier
+ real(2), device :: src(*)
+ real(2), shared :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+
+ attributes(device) subroutine tma_bulk_ldr4(barrier, src, dst, nelems)
+ !dir$ ignore_tkr (r) src, (r) dst
+ integer(8), shared :: barrier
+ real(4), device :: src(*)
+ real(4), shared :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+
+ attributes(device) subroutine tma_bulk_ldr8(barrier, src, dst, nelems)
+ !dir$ ignore_tkr (r) src, (r) dst
+ integer(8), shared :: barrier
+ real(8), device :: src(*)
+ real(8), shared :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+
+ attributes(device) subroutine tma_bulk_ldc4(barrier, src, dst, nelems)
+ !dir$ ignore_tkr (r) src, (r) dst
+ integer(8), shared :: barrier
+ complex(4), device :: src(*)
+ complex(4), shared :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+
+ attributes(device) subroutine tma_bulk_ldc8(barrier, src, dst, nelems)
+ !dir$ ignore_tkr (r) src, (r) dst
+ integer(8), shared :: barrier
+ complex(8), device :: src(*)
+ complex(8), shared :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+ end interface
+
+
contains
attributes(device) subroutine syncthreads()
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 99b1a2fc0cbf7..36d5fb150ad32 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -514,3 +514,136 @@ end subroutine
! CHECK-LABEL: func.func @_QPtest_barrier_try_wait_sleep()
! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %0, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr, i64, i32) -> i32
+
+attributes(global) subroutine test_tma_bulk_load_c4(a, n)
+ integer(8), shared :: barrier1
+ integer, value :: n
+ complex(4), device :: r8(n)
+ complex(4), shared :: tmp(1024)
+ integer(4) :: j, elem_count
+ call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_c4
+! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_c4Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_c4Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
+! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
+! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xcomplex<f32>>>, !fir.ref<complex<f32>>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
+
+attributes(global) subroutine test_tma_bulk_load_c8(a, n)
+ integer(8), shared :: barrier1
+ integer, value :: n
+ complex(8), device :: r8(n)
+ complex(8), shared :: tmp(1024)
+ integer(4) :: j, elem_count
+ call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_c8
+! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_c8Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_c8Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
+! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 16 : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
+! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xcomplex<f64>>>, !fir.ref<complex<f64>>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
+
+attributes(global) subroutine test_tma_bulk_load_i4(a, n)
+ integer(8), shared :: barrier1
+ integer, value :: n
+ integer(4), device :: r8(n)
+ integer(4), shared :: tmp(1024)
+ integer(4) :: j, elem_count
+ call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_i4
+! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_i4Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_i4Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
+! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
+! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xi32>>, !fir.ref<i32>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
+
+attributes(global) subroutine test_tma_bulk_load_i8(a, n)
+ integer(8), shared :: barrier1
+ integer, value :: n
+ integer(8), device :: r8(n)
+ integer(8), shared :: tmp(1024)
+ integer(4) :: j, elem_count
+ call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_i8
+! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_i8Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_i8Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
+! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
+! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xi64>>, !fir.ref<i64>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
+
+attributes(global) subroutine test_tma_bulk_load_r2(a, n)
+ integer(8), shared :: barrier1
+ integer, value :: n
+ real(2), device :: r8(n)
+ real(2), shared :: tmp(1024)
+ integer(4) :: j, elem_count
+ call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_r2
+! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_r2Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_r2Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
+! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 2 : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
+! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf16>>, !fir.ref<f16>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
+
+attributes(global) subroutine test_tma_bulk_load_r4(a, n)
+ integer(8), shared :: barrier1
+ integer, value :: n
+ real(4), device :: r8(n)
+ real(4), shared :: tmp(1024)
+ integer(4) :: j, elem_count
+ call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_r4
+! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_r4Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_r4Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
+! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
+! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf32>>, !fir.ref<f32>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
+
+attributes(global) subroutine test_tma_bulk_load_r8(a, n)
+ integer(8), shared :: barrier1
+ integer, value :: n
+ real(8), device :: r8(n)
+ real(8), shared :: tmp(1024)
+ integer(4) :: j, elem_count
+ call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_r8
+! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_r8Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_r8Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
+! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
+! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf64>>, !fir.ref<f64>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
From 71b2e7e485fbc8b9fe49aa28e3a81d005ee65a4d Mon Sep 17 00:00:00 2001
From: Valentin Clement (バレンタイン クレメン) <clementval at gmail.com>
Date: Tue, 28 Oct 2025 14:07:28 -0700
Subject: [PATCH 2/2] Update flang/module/cudadevice.f90
---
flang/module/cudadevice.f90 | 32 ++++++++++++++++----------------
1 file changed, 16 insertions(+), 16 deletions(-)
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index ae8ebf7c9562d..e6decbb96c271 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -2070,6 +2070,22 @@ attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
! Load specific types, count is in elements
! -----------------------------------------
interface tma_bulk_load
+ attributes(device) subroutine tma_bulk_ldc4(barrier, src, dst, nelems)
+ !dir$ ignore_tkr (r) src, (r) dst
+ integer(8), shared :: barrier
+ complex(4), device :: src(*)
+ complex(4), shared :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+
+ attributes(device) subroutine tma_bulk_ldc8(barrier, src, dst, nelems)
+ !dir$ ignore_tkr (r) src, (r) dst
+ integer(8), shared :: barrier
+ complex(8), device :: src(*)
+ complex(8), shared :: dst(*)
+ integer(4), value :: nelems
+ end subroutine
+
attributes(device) subroutine tma_bulk_ldi4(barrier, src, dst, nelems)
!dir$ ignore_tkr (r) src, (r) dst
integer(8), shared :: barrier
@@ -2109,22 +2125,6 @@ attributes(device) subroutine tma_bulk_ldr8(barrier, src, dst, nelems)
real(8), shared :: dst(*)
integer(4), value :: nelems
end subroutine
-
- attributes(device) subroutine tma_bulk_ldc4(barrier, src, dst, nelems)
- !dir$ ignore_tkr (r) src, (r) dst
- integer(8), shared :: barrier
- complex(4), device :: src(*)
- complex(4), shared :: dst(*)
- integer(4), value :: nelems
- end subroutine
-
- attributes(device) subroutine tma_bulk_ldc8(barrier, src, dst, nelems)
- !dir$ ignore_tkr (r) src, (r) dst
- integer(8), shared :: barrier
- complex(8), device :: src(*)
- complex(8), shared :: dst(*)
- integer(4), value :: nelems
- end subroutine
end interface
More information about the flang-commits mailing list