[flang-commits] [flang] [flang][cuda] Add interfaces and lowering for tma_bulk_store (PR #165482)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Tue Oct 28 14:58:13 PDT 2025


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/165482

>From af2e53c1c1aaa2621aa3540f69a1fcd02391b7f7 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 28 Oct 2025 14:35:44 -0700
Subject: [PATCH] [flang][cuda] Add interfaces and lowering for tma_bulk_store

---
 .../flang/Optimizer/Builder/IntrinsicCall.h   |   7 ++
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 114 ++++++++++++++++++
 flang/module/cudadevice.f90                   |  85 +++++++++++--
 flang/test/Lower/CUDA/cuda-device-proc.cuf    |  92 ++++++++++++++
 4 files changed, 289 insertions(+), 9 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index f5ff6626da654..3407dd01dd504 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -469,6 +469,13 @@ struct IntrinsicLibrary {
   void genTMABulkLoadR4(llvm::ArrayRef<fir::ExtendedValue>);
   void genTMABulkLoadR8(llvm::ArrayRef<fir::ExtendedValue>);
   void genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkStoreI4(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkStoreI8(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkStoreR2(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkStoreR4(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkStoreR8(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkStoreC4(llvm::ArrayRef<fir::ExtendedValue>);
+  void genTMABulkStoreC8(llvm::ArrayRef<fir::ExtendedValue>);
   void genTMABulkWaitGroup(llvm::ArrayRef<fir::ExtendedValue>);
   mlir::Value genTrailz(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue genTransfer(mlir::Type,
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 65317599ecd35..53fe9c0d2f6f0 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1098,6 +1098,34 @@ static constexpr IntrinsicHandler handlers[]{
      &I::genTMABulkS2G,
      {{{"src", asAddr}, {"dst", asAddr}, {"nbytes", asValue}}},
      /*isElemental=*/false},
+    {"tma_bulk_store_c4",
+     &I::genTMABulkStoreC4,
+     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+     /*isElemental=*/false},
+    {"tma_bulk_store_c8",
+     &I::genTMABulkStoreC8,
+     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+     /*isElemental=*/false},
+    {"tma_bulk_store_i4",
+     &I::genTMABulkStoreI4,
+     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+     /*isElemental=*/false},
+    {"tma_bulk_store_i8",
+     &I::genTMABulkStoreI8,
+     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+     /*isElemental=*/false},
+    {"tma_bulk_store_r2",
+     &I::genTMABulkStoreR2,
+     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+     /*isElemental=*/false},
+    {"tma_bulk_store_r4",
+     &I::genTMABulkStoreR4,
+     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+     /*isElemental=*/false},
+    {"tma_bulk_store_r8",
+     &I::genTMABulkStoreR8,
+     {{{"src", asAddr}, {"dst", asAddr}, {"count", asValue}}},
+     /*isElemental=*/false},
     {"tma_bulk_wait_group",
      &I::genTMABulkWaitGroup,
      {{}},
@@ -9430,6 +9458,92 @@ void IntrinsicLibrary::genTMABulkS2G(llvm::ArrayRef<fir::ExtendedValue> args) {
                                              builder.getI32IntegerAttr(0), {});
 }
 
+static void genTMABulkStore(fir::FirOpBuilder &builder, mlir::Location loc,
+                            mlir::Value src, mlir::Value dst, mlir::Value count,
+                            mlir::Value eleSize) {
+  mlir::Value size = mlir::arith::MulIOp::create(builder, loc, eleSize, count);
+  src = convertPtrToNVVMSpace(builder, loc, src,
+                              mlir::NVVM::NVVMMemorySpace::Shared);
+  dst = convertPtrToNVVMSpace(builder, loc, dst,
+                              mlir::NVVM::NVVMMemorySpace::Global);
+  mlir::NVVM::CpAsyncBulkSharedCTAToGlobalOp::create(builder, loc, dst, src,
+                                                     size, {}, {});
+  mlir::NVVM::InlinePtxOp::create(builder, loc, mlir::TypeRange{}, {}, {},
+                                  "cp.async.bulk.commit_group", {});
+  mlir::NVVM::CpAsyncBulkWaitGroupOp::create(builder, loc,
+                                             builder.getI32IntegerAttr(0), {});
+}
+
+// TMA_BULK_STORE_C4 (CUDA)
+void IntrinsicLibrary::genTMABulkStoreC4(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 3);
+  mlir::Value eleSize =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 8);
+  genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+                  fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_C8 (CUDA)
+void IntrinsicLibrary::genTMABulkStoreC8(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 3);
+  mlir::Value eleSize =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 16);
+  genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+                  fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_I4 (CUDA)
+void IntrinsicLibrary::genTMABulkStoreI4(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 3);
+  mlir::Value eleSize =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 4);
+  genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+                  fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_I8 (CUDA)
+void IntrinsicLibrary::genTMABulkStoreI8(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 3);
+  mlir::Value eleSize =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 8);
+  genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+                  fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_R2 (CUDA)
+void IntrinsicLibrary::genTMABulkStoreR2(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 3);
+  mlir::Value eleSize =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 2);
+  genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+                  fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_R4 (CUDA)
+void IntrinsicLibrary::genTMABulkStoreR4(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 3);
+  mlir::Value eleSize =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 4);
+  genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+                  fir::getBase(args[2]), eleSize);
+}
+
+// TMA_BULK_STORE_R8 (CUDA)
+void IntrinsicLibrary::genTMABulkStoreR8(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 3);
+  mlir::Value eleSize =
+      builder.createIntegerConstant(loc, builder.getI32Type(), 8);
+  genTMABulkStore(builder, loc, fir::getBase(args[0]), fir::getBase(args[1]),
+                  fir::getBase(args[2]), eleSize);
+}
+
 // TMA_BULK_WAIT_GROUP (CUDA)
 void IntrinsicLibrary::genTMABulkWaitGroup(
     llvm::ArrayRef<fir::ExtendedValue> args) {
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index e6decbb96c271..59af58ddcd32e 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -2047,7 +2047,13 @@ attributes(device) subroutine tma_bulk_wait_group()
     end subroutine
   end interface
 
+  ! --------------------
+  ! Bulk load functions
+  ! --------------------
+
   ! Generic load, count is in bytes
+  ! -------------------------------
+
   interface
     attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes)
       !dir$ ignore_tkr src, dst
@@ -2058,17 +2064,9 @@ attributes(device) subroutine tma_bulk_g2s(barrier, src, dst, nbytes)
     end subroutine
   end interface
 
-  interface
-    attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
-      !dir$ ignore_tkr src, dst
-      integer(4), shared  :: src(*)
-      integer(4), device  :: dst(*)
-      integer(4), value   :: nbytes
-    end subroutine
-  end interface
-
   ! Load specific types, count is in elements
   ! -----------------------------------------
+
   interface tma_bulk_load
     attributes(device) subroutine tma_bulk_ldc4(barrier, src, dst, nelems)
       !dir$ ignore_tkr (r) src, (r) dst
@@ -2127,6 +2125,75 @@ attributes(device) subroutine tma_bulk_ldr8(barrier, src, dst, nelems)
     end subroutine
   end interface
 
+  ! --------------------
+  ! Bulk store functions
+  ! --------------------
+
+  ! Generic store, count is in bytes
+  ! --------------------------------
+
+  interface
+    attributes(device) subroutine tma_bulk_s2g(src, dst, nbytes)
+      !dir$ ignore_tkr src, dst
+      integer(4), shared  :: src(*)
+      integer(4), device  :: dst(*)
+      integer(4), value   :: nbytes
+    end subroutine
+  end interface
+
+  ! Store specific types, count is in elements
+  ! -----------------------------------------
+
+  interface tma_bulk_store
+    attributes(device) subroutine tma_bulk_store_c4(src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      complex(4), shared :: src(*)
+      complex(4), device :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+
+    attributes(device) subroutine tma_bulk_store_c8(src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      complex(8), shared :: src(*)
+      complex(8), device :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+
+    attributes(device) subroutine tma_bulk_store_i4(src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      integer(4), shared :: src(*)
+      integer(4), device :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+
+    attributes(device) subroutine tma_bulk_store_i8(src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      integer(8), shared :: src(*)
+      integer(8), device :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+
+    attributes(device) subroutine tma_bulk_store_r2(src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      real(2), shared :: src(*)
+      real(2), device :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+
+    attributes(device) subroutine tma_bulk_store_r4(src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      real(4), shared :: src(*)
+      real(4), device :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+
+    attributes(device) subroutine tma_bulk_store_r8(src, dst, nelems)
+      !dir$ ignore_tkr (r) src, (r) dst
+      real(8), shared :: src(*)
+      real(8), device :: dst(*)
+      integer(4), value :: nelems
+    end subroutine
+  end interface
 
 contains
 
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 71d3d1ef2e2e9..8f355217899b3 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -649,3 +649,95 @@ end subroutine
 ! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
 ! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf64>>, !fir.ref<f64>, i32, !llvm.ptr)
 ! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
+
+attributes(global) subroutine test_tma_bulk_store_c4(c, n)
+  integer, value :: n
+  complex(4), device :: c(n)
+  complex(4), shared :: tmpa(1024)
+  integer(4) :: j, elem_count
+  call tma_bulk_store(tmpa, c(j), elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c4
+! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
+! CHECK: nvvm.cp.async.bulk.wait_group 0
+
+attributes(global) subroutine test_tma_bulk_store_c8(c, n)
+  integer, value :: n
+  complex(8), device :: c(n)
+  complex(8), shared :: tmpa(1024)
+  integer(4) :: j, elem_count
+  call tma_bulk_store(tmpa, c(j), elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c8
+! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
+! CHECK: nvvm.cp.async.bulk.wait_group 0
+
+attributes(global) subroutine test_tma_bulk_store_i4(c, n)
+  integer, value :: n
+  integer(4), device :: c(n)
+  integer(4), shared :: tmpa(1024)
+  integer(4) :: j, elem_count
+  call tma_bulk_store(tmpa, c(j), elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i4
+! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
+! CHECK: nvvm.cp.async.bulk.wait_group 0
+
+attributes(global) subroutine test_tma_bulk_store_i8(c, n)
+  integer, value :: n
+  integer(8), device :: c(n)
+  integer(8), shared :: tmpa(1024)
+  integer(4) :: j, elem_count
+  call tma_bulk_store(tmpa, c(j), elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i8
+! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
+! CHECK: nvvm.cp.async.bulk.wait_group 0
+
+
+attributes(global) subroutine test_tma_bulk_store_r2(c, n)
+  integer, value :: n
+  real(2), device :: c(n)
+  real(2), shared :: tmpa(1024)
+  integer(4) :: j, elem_count
+  call tma_bulk_store(tmpa, c(j), elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r2
+! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
+! CHECK: nvvm.cp.async.bulk.wait_group 0
+
+attributes(global) subroutine test_tma_bulk_store_r4(c, n)
+  integer, value :: n
+  real(4), device :: c(n)
+  real(4), shared :: tmpa(1024)
+  integer(4) :: j, elem_count
+  call tma_bulk_store(tmpa, c(j), elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r4
+! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
+! CHECK: nvvm.cp.async.bulk.wait_group 0
+
+attributes(global) subroutine test_tma_bulk_store_r8(c, n)
+  integer, value :: n
+  real(8), device :: c(n)
+  real(8), shared :: tmpa(1024)
+  integer(4) :: j, elem_count
+  call tma_bulk_store(tmpa, c(j), elem_count)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r8
+! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
+! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group"
+! CHECK: nvvm.cp.async.bulk.wait_group 0



More information about the flang-commits mailing list