[flang-commits] [flang] [flang][cuda] Add interfaces and lowering for tma_bulk_load (PR #165474)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Tue Oct 28 14:07:36 PDT 2025


Valentin Clement (バレンタイン クレメン)
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/165474 at github.com>


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/165474

>From 17502dc89f49e047b8c0dee86da6518316f51040 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 28 Oct 2025 13:42:28 -0700
Subject: [PATCH 1/2] [flang][cuda] Add interfaces and lowering for
 tma_bulk_load

As defined in https://docs.nvidia.com/hpc-sdk/compilers/cuda-fortran-prog-guide/#load-and-store-functions-using-bulk-tma-operations
---
 .../flang/Optimizer/Builder/IntrinsicCall.h   |   7 +
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 136 ++++++++++++++++++
 flang/module/cudadevice.f90                   |  61 ++++++++
 flang/test/Lower/CUDA/cuda-device-proc.cuf    | 133 +++++++++++++++++
 4 files changed, 337 insertions(+)

diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index ed0cbd3bdf16b..f5ff6626da654 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -461,6 +461,13 @@ struct IntrinsicLibrary {
   mlir::Value genTime(mlir::Type, llvm::ArrayRef<mlir::Value>);
   void genTMABulkCommitGroup(llvm::ArrayRef<fir::ExtendedValue>);
   void genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkLoadC4(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkLoadC8(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkLoadI4(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkLoadI8(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkLoadR2(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkLoadR4(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkLoadR8(llvm::ArrayRef<fir::ExtendedValue>);
   void genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue>);
   void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
   mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 0d225532f2460..c6efc8cfcb579 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1045,6 +1045,55 @@ static constexpr IntrinsicHandler handlers[]{
        {"dst", asAddr},
        {"nbytes", asValue}}},
      /*isElemental=*/false},
+    {"tma_bulk_ldc4",
+     &I::genTMABulkLoadC4,
+     {{{"barrier", asAddr},
+       {"src", asAddr},
+       {"dst", asAddr},
+       {"nelems", asValue}}},
+     /*isElemental=*/false},
+    {"tma_bulk_ldc8",
+     &I::genTMABulkLoadC8,
+     {{{"barrier", asAddr},
+       {"src", asAddr},
+       {"dst", asAddr},
+       {"nelems", asValue}}},
+     /*isElemental=*/false},
+    {"tma_bulk_ldi4",
+     &I::genTMABulkLoadI4,
+     {{{"barrier", asAddr},
+       {"src", asAddr},
+       {"dst", asAddr},
+       {"nelems", asValue}}},
+     /*isElemental=*/false},
+    {"tma_bulk_ldi8",
+     &I::genTMABulkLoadI8,
+     {{{"barrier", asAddr},
+       {"src", asAddr},
+       {"dst", asAddr},
+       {"nelems", asValue}}},
+     /*isElemental=*/false},
+    {"tma_bulk_ldr2",
+     &I::genTMABulkLoadR2,
+     {{{"barrier", asAddr},
+       {"src", asAddr},
+       {"dst", asAddr},
+       {"nelems", asValue}}},
+     /*isElemental=*/false},
+    {"tma_bulk_ldr4",
+     &I::genTMABulkLoadR4,
+     {{{"barrier", asAddr},
+       {"src", asAddr},
+       {"dst", asAddr},
+       {"nelems", asValue}}},
+     /*isElemental=*/false},
+    {"tma_bulk_ldr8",
+     &I::genTMABulkLoadR8,
+     {{{"barrier", asAddr},
+       {"src", asAddr},
+       {"dst", asAddr},
+       {"nelems", asValue}}},
+     /*isElemental=*/false},
     {"tma_bulk_s2g",
      &I::genTMABulkS2G,
      {{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}},
@@ -9278,6 +9327,93 @@ void IntrinsicLibrary::genTMABulkG2S(llvm::ArrayRef<fir::ExtendedValue> args) {
       builder, loc, dst, src, barrier, fir::getBase(args[3]), {}, {});
 }
 
+static void genTMABulkLoad(fir::FirOpBuilder &builder, mlir::Location loc,
+                           mlir::Value barrier, mlir::Value src,
+                           mlir::Value dst, mlir::Value nelem,
+                           mlir::Value eleSize) { // eleSize: bytes per element
+  mlir::Value size = mlir::arith::MulIOp::create(builder, loc, nelem, eleSize);
+  auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
+  barrier = builder.createConvert(loc, llvmPtrTy, barrier);
+  mlir::NVVM::InlinePtxOp::create(
+      builder, loc, mlir::TypeRange{}, {dst, src, size, barrier}, {},
+      "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], "
+      "[%1], %2, [%3];",
+      {}); // global -> shared bulk copy of `size` bytes
+  mlir::NVVM::InlinePtxOp::create(
+      builder, loc, mlir::TypeRange{}, {barrier, size}, {},
+      "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;", {});
+}
+
+// TMA_BULK_LOADC4 (CUDA): tma_bulk_ldc4, complex(4), 8 bytes per element
+void IntrinsicLibrary::genTMABulkLoadC4(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 4);
+  mlir::Value eleSize =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 8);
+  genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+                 fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADC8 (CUDA): tma_bulk_ldc8, complex(8), 16 bytes per element
+void IntrinsicLibrary::genTMABulkLoadC8(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 4);
+  mlir::Value eleSize =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 16);
+  genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+                 fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADI4 (CUDA): tma_bulk_ldi4, integer(4), 4 bytes per element
+void IntrinsicLibrary::genTMABulkLoadI4(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 4);
+  mlir::Value eleSize =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 4);
+  genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+                 fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADI8 (CUDA): tma_bulk_ldi8, integer(8), 8 bytes per element
+void IntrinsicLibrary::genTMABulkLoadI8(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 4);
+  mlir::Value eleSize =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 8);
+  genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+                 fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADR2 (CUDA): tma_bulk_ldr2, real(2), 2 bytes per element
+void IntrinsicLibrary::genTMABulkLoadR2(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 4);
+  mlir::Value eleSize =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 2);
+  genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+                 fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADR4 (CUDA): tma_bulk_ldr4, real(4), 4 bytes per element
+void IntrinsicLibrary::genTMABulkLoadR4(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 4);
+  mlir::Value eleSize =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 4);
+  genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+                 fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
+// TMA_BULK_LOADR8 (CUDA): tma_bulk_ldr8, real(8), 8 bytes per element
+void IntrinsicLibrary::genTMABulkLoadR8(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 4);
+  mlir::Value eleSize =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 8);
+  genTMABulkLoad(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+                 fir::getBase(args[2]), fir::getBase(args[3]), eleSize);
+}
+
 // TMA_BULK_S2G (CUDA)
 void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {
   assert(args.size() == 3);
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index ea54c974c9e7c..ae8ebf7c9562d 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -2067,6 +2067,67 @@ attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
     end subroutine
   end interface
 
+  ! Load specific types, count is in elements
+  ! -----------------------------------------
+  interface tma_bulk_load
+    attributes(device) subroutine tma_bulk_ldi4(barrier, src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      integer(8), shared :: barrier
+      integer(4), device :: src(*)
+      integer(4), shared :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+
+    attributes(device) subroutine tma_bulk_ldi8(barrier, src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      integer(8), shared :: barrier
+      integer(8), device :: src(*)
+      integer(8), shared :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+
+    attributes(device) subroutine tma_bulk_ldr2(barrier, src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      integer(8), shared :: barrier
+      real(2), device :: src(*)
+      real(2), shared :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+
+    attributes(device) subroutine tma_bulk_ldr4(barrier, src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      integer(8), shared :: barrier
+      real(4), device :: src(*)
+      real(4), shared :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+
+    attributes(device) subroutine tma_bulk_ldr8(barrier, src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      integer(8), shared :: barrier
+      real(8), device :: src(*)
+      real(8), shared :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+
+    attributes(device) subroutine tma_bulk_ldc4(barrier, src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      integer(8), shared :: barrier
+      complex(4), device :: src(*)
+      complex(4), shared :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+
+    attributes(device) subroutine tma_bulk_ldc8(barrier, src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      integer(8), shared :: barrier
+      complex(8), device :: src(*)
+      complex(8), shared :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+  end interface
+
+
 contains
 
   attributes(device) subroutine syncthreads()
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 99b1a2fc0cbf7..36d5fb150ad32 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -514,3 +514,136 @@ end subroutine
 
 ! CHECK-LABEL: func.func @_QPtest_barrier_try_wait_sleep()
 ! CHECK: %{{.*}} = nvvm.inline_ptx ".reg .pred p; mbarrier.try_wait.shared.b64 p, [%{{.*}}], %{{.*}}, %{{.*}}; selp.b32 %0, 1, 0, p;" ro(%{{.*}}, %{{.*}}, %{{.*}} : !llvm.ptr, i64, i32) -> i32
+
+attributes(global) subroutine test_tma_bulk_load_c4(a, n)
+  integer(8), shared :: barrier1
+  integer, value :: n
+  complex(4), device :: r8(n)
+  complex(4), shared :: tmp(1024)
+  integer(4) :: j, elem_count
+  call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_c4
+! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_c4Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_c4Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
+! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
+! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xcomplex<f32>>>, !fir.ref<complex<f32>>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
+
+attributes(global) subroutine test_tma_bulk_load_c8(a, n)
+  integer(8), shared :: barrier1
+  integer, value :: n
+  complex(8), device :: r8(n)
+  complex(8), shared :: tmp(1024)
+  integer(4) :: j, elem_count
+  call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_c8
+! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_c8Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_c8Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
+! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 16 : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
+! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xcomplex<f64>>>, !fir.ref<complex<f64>>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
+
+attributes(global) subroutine test_tma_bulk_load_i4(a, n)
+  integer(8), shared :: barrier1
+  integer, value :: n
+  integer(4), device :: r8(n)
+  integer(4), shared :: tmp(1024)
+  integer(4) :: j, elem_count
+  call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_i4
+! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_i4Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_i4Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
+! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
+! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xi32>>, !fir.ref<i32>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
+
+attributes(global) subroutine test_tma_bulk_load_i8(a, n)
+  integer(8), shared :: barrier1
+  integer, value :: n
+  integer(8), device :: r8(n)
+  integer(8), shared :: tmp(1024)
+  integer(4) :: j, elem_count
+  call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_i8
+! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_i8Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_i8Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
+! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
+! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xi64>>, !fir.ref<i64>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
+
+attributes(global) subroutine test_tma_bulk_load_r2(a, n)
+  integer(8), shared :: barrier1
+  integer, value :: n
+  real(2), device :: r8(n)
+  real(2), shared :: tmp(1024)
+  integer(4) :: j, elem_count
+  call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_r2
+! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_r2Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_r2Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
+! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 2 : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
+! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf16>>, !fir.ref<f16>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
+
+attributes(global) subroutine test_tma_bulk_load_r4(a, n)
+  integer(8), shared :: barrier1
+  integer, value :: n
+  real(4), device :: r8(n)
+  real(4), shared :: tmp(1024)
+  integer(4) :: j, elem_count
+  call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_r4
+! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_r4Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_r4Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
+! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
+! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf32>>, !fir.ref<f32>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
+
+attributes(global) subroutine test_tma_bulk_load_r8(a, n)
+  integer(8), shared :: barrier1
+  integer, value :: n
+  real(8), device :: r8(n)
+  real(8), shared :: tmp(1024)
+  integer(4) :: j, elem_count
+  call tma_bulk_load(barrier1, r8(j), tmp, elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_load_r8
+! CHECK: %[[BARRIER:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<shared>, uniq_name = "_QFtest_tma_bulk_load_r8Ebarrier1"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
+! CHECK: %[[ELEM_COUNT:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_tma_bulk_load_r8Eelem_count"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[COUNT:.*]] = fir.load %[[ELEM_COUNT]]#0 : !fir.ref<i32>
+! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
+! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
+! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf64>>, !fir.ref<f64>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)

>From 71b2e7e485fbc8b9fe49aa28e3a81d005ee65a4d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?=
 =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?=
 =?UTF-8?q?=E3=83=B3=29?= <clementval at gmail.com>
Date: Tue, 28 Oct 2025 14:07:28 -0700
Subject: [PATCH 2/2] Update flang/module/cudadevice.f90

---
 flang/module/cudadevice.f90 | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index ae8ebf7c9562d..e6decbb96c271 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -2070,6 +2070,22 @@ attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
   ! Load specific types, count is in elements
   ! -----------------------------------------
   interface tma_bulk_load
+    attributes(device) subroutine tma_bulk_ldc4(barrier, src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      integer(8), shared :: barrier
+      complex(4), device :: src(*)
+      complex(4), shared :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+
+    attributes(device) subroutine tma_bulk_ldc8(barrier, src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      integer(8), shared :: barrier
+      complex(8), device :: src(*)
+      complex(8), shared :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+  
     attributes(device) subroutine tma_bulk_ldi4(barrier, src, dst, nelems)
       !dir$ ignore_tkr (r) src, (r) dst
       integer(8), shared :: barrier
@@ -2109,22 +2125,6 @@ attributes(device) subroutine tma_bulk_ldr8(barrier, src, dst, nelems)
       real(8), shared :: dst(*)
       integer(4), value :: nelems
     end subroutine
-
-    attributes(device) subroutine tma_bulk_ldc4(barrier, src, dst, nelems)
-      !dir$ ignore_tkr (r) src, (r) dst
-      integer(8), shared :: barrier
-      complex(4), device :: src(*)
-      complex(4), shared :: dst(*)
-      integer(4), value :: nelems
-    end subroutine
-
-    attributes(device) subroutine tma_bulk_ldc8(barrier, src, dst, nelems)
-      !dir$ ignore_tkr (r) src, (r) dst
-      integer(8), shared :: barrier
-      complex(8), device :: src(*)
-      complex(8), shared :: dst(*)
-      integer(4), value :: nelems
-    end subroutine
   end interface
 
 



More information about the flang-commits mailing list