[flang-commits] [flang] 848d865 - [flang][cuda] Set alignment for tma bulk store (#170558)

via flang-commits flang-commits at lists.llvm.org
Wed Dec 3 14:22:36 PST 2025


Author: Valentin Clement (バレンタイン クレメン)
Date: 2025-12-03T14:22:32-08:00
New Revision: 848d8657aec798b2630b8dfb57c009e7060d6d49

URL: https://github.com/llvm/llvm-project/commit/848d8657aec798b2630b8dfb57c009e7060d6d49
DIFF: https://github.com/llvm/llvm-project/commit/848d8657aec798b2630b8dfb57c009e7060d6d49.diff

LOG: [flang][cuda] Set alignment for tma bulk store (#170558)

The shared-memory buffer needs to be 16-byte aligned, just as it is for the bulk load operation.
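
A minimal CUDA C++ sketch of the shared-to-global bulk copy this path lowers to is shown below. It is illustrative only, not code emitted by flang: the kernel name, buffer name, and sizes are assumptions made for this example, while the PTX instruction names mirror the CHECK lines in the updated test. cp.async.bulk requires a 16-byte aligned shared-memory source and an sm_90+ target.

// Illustrative CUDA C++ sketch (not flang-generated code) of a TMA bulk
// store: a shared-to-global cp.async.bulk copy whose shared-memory source
// must be 16-byte aligned.  Assumes an sm_90+ target.
#include <cstdint>

__global__ void bulk_store_sketch(float *dst, unsigned nbytes) {
  // The shared-memory source of cp.async.bulk must be 16-byte aligned;
  // this is the alignment the lowering now sets on the cuf.shared_memory op.
  __shared__ __align__(16) float tmpa[1024];

  // ... fill tmpa ...

  // Convert the generic shared-memory pointer to a 32-bit shared address.
  std::uint32_t src =
      static_cast<std::uint32_t>(__cvta_generic_to_shared(tmpa));

  // Same instruction sequence as the nvvm ops checked in the test:
  // cp.async.bulk.global.shared::cta, commit_group, wait_group 0.
  // nbytes must be a multiple of 16.
  asm volatile(
      "cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n\t"
      "cp.async.bulk.commit_group;\n\t"
      "cp.async.bulk.wait_group 0;"
      :
      : "l"(dst), "r"(src), "r"(nbytes)
      : "memory");
}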

Added: 
    

Modified: 
    flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
    flang/test/Lower/CUDA/cuda-device-proc.cuf

Removed: 
    


################################################################################
diff --git a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
index ae6120826f8d2..d324fe102ab7a 100644
--- a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
@@ -53,6 +53,8 @@ static const char __ldlu_r2x2[] = "__ldlu_r2x2_";
 static const char __ldlu_r4x4[] = "__ldlu_r4x4_";
 static const char __ldlu_r8x2[] = "__ldlu_r8x2_";
 
+static constexpr unsigned kTMAAlignment = 16;
+
 // CUDA specific intrinsic handlers.
 static constexpr IntrinsicHandler cudaHandlers[]{
     {"__ldca_i4x4",
@@ -1505,7 +1507,7 @@ static void genTMABulkLoad(fir::FirOpBuilder &builder, mlir::Location loc,
   mlir::Value size = mlir::arith::MulIOp::create(builder, loc, nelem, eleSize);
   auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
   barrier = builder.createConvert(loc, llvmPtrTy, barrier);
-  setAlignment(dst, 16);
+  setAlignment(dst, kTMAAlignment);
   dst = convertPtrToNVVMSpace(builder, loc, dst,
                               mlir::NVVM::NVVMMemorySpace::Shared);
   src = convertPtrToNVVMSpace(builder, loc, src,
@@ -1611,6 +1613,7 @@ static void genTMABulkStore(fir::FirOpBuilder &builder, mlir::Location loc,
                             mlir::Value src, mlir::Value dst, mlir::Value count,
                             mlir::Value eleSize) {
   mlir::Value size = mlir::arith::MulIOp::create(builder, loc, eleSize, count);
+  setAlignment(src, kTMAAlignment);
   src = convertPtrToNVVMSpace(builder, loc, src,
                               mlir::NVVM::NVVMMemorySpace::Shared);
   dst = convertPtrToNVVMSpace(builder, loc, dst,

diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 1e3c66307c334..20524b8f7de96 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -675,6 +675,7 @@ attributes(global) subroutine test_tma_bulk_store_c4(c, n)
 end subroutine
 
 ! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c4
+! CHECK: cuf.shared_memory !fir.array<1024xcomplex<f32>> align 16 {bindc_name = "tmpa", uniq_name = "_QFtest_tma_bulk_store_c4Etmpa"} -> !fir.ref<!fir.array<1024xcomplex<f32>>>
 ! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
 ! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group;"
 ! CHECK: nvvm.cp.async.bulk.wait_group 0
@@ -688,6 +689,7 @@ attributes(global) subroutine test_tma_bulk_store_c8(c, n)
 end subroutine
 
 ! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c8
+! CHECK: cuf.shared_memory !fir.array<1024xcomplex<f64>> align 16 {bindc_name = "tmpa", uniq_name = "_QFtest_tma_bulk_store_c8Etmpa"} -> !fir.ref<!fir.array<1024xcomplex<f64>>>
 ! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
 ! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group;"
 ! CHECK: nvvm.cp.async.bulk.wait_group 0
@@ -701,6 +703,7 @@ attributes(global) subroutine test_tma_bulk_store_i4(c, n)
 end subroutine
 
 ! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i4
+! CHECK: cuf.shared_memory !fir.array<1024xi32> align 16 {bindc_name = "tmpa", uniq_name = "_QFtest_tma_bulk_store_i4Etmpa"} -> !fir.ref<!fir.array<1024xi32>>
 ! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
 ! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group;"
 ! CHECK: nvvm.cp.async.bulk.wait_group 0
@@ -714,6 +717,7 @@ attributes(global) subroutine test_tma_bulk_store_i8(c, n)
 end subroutine
 
 ! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i8
+! CHECK: cuf.shared_memory !fir.array<1024xi64> align 16 {bindc_name = "tmpa", uniq_name = "_QFtest_tma_bulk_store_i8Etmpa"} -> !fir.ref<!fir.array<1024xi64>>
 ! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
 ! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group;"
 ! CHECK: nvvm.cp.async.bulk.wait_group 0
@@ -728,6 +732,7 @@ attributes(global) subroutine test_tma_bulk_store_r2(c, n)
 end subroutine
 
 ! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r2
+! CHECK: cuf.shared_memory !fir.array<1024xf16> align 16 {bindc_name = "tmpa", uniq_name = "_QFtest_tma_bulk_store_r2Etmpa"} -> !fir.ref<!fir.array<1024xf16>>
 ! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
 ! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group;"
 ! CHECK: nvvm.cp.async.bulk.wait_group 0
@@ -741,6 +746,7 @@ attributes(global) subroutine test_tma_bulk_store_r4(c, n)
 end subroutine
 
 ! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r4
+! CHECK: cuf.shared_memory !fir.array<1024xf32> align 16 {bindc_name = "tmpa", uniq_name = "_QFtest_tma_bulk_store_r4Etmpa"} -> !fir.ref<!fir.array<1024xf32>>
 ! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
 ! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group;"
 ! CHECK: nvvm.cp.async.bulk.wait_group 0
@@ -754,6 +760,7 @@ attributes(global) subroutine test_tma_bulk_store_r8(c, n)
 end subroutine
 
 ! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r8
+! CHECK: cuf.shared_memory !fir.array<1024xf64> align 16 {bindc_name = "tmpa", uniq_name = "_QFtest_tma_bulk_store_r8Etmpa"} -> !fir.ref<!fir.array<1024xf64>>
 ! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
 ! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group;"
 ! CHECK: nvvm.cp.async.bulk.wait_group 0

