[flang-commits] [flang] e52aece - [flang][cuda] Do not use address cast for src and dst in TMA bulk load (#170564)

via flang-commits flang-commits at lists.llvm.org
Wed Dec 3 14:39:02 PST 2025


Author: Valentin Clement (バレンタイン クレメン)
Date: 2025-12-03T22:38:58Z
New Revision: e52aece16b98a2fba9b84f411bd4c5f35dfe60c1

URL: https://github.com/llvm/llvm-project/commit/e52aece16b98a2fba9b84f411bd4c5f35dfe60c1
DIFF: https://github.com/llvm/llvm-project/commit/e52aece16b98a2fba9b84f411bd4c5f35dfe60c1.diff

LOG: [flang][cuda] Do not use address cast for src and dst in TMA bulk load (#170564)

Added: 
    

Modified: 
    flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
    flang/test/Lower/CUDA/cuda-device-proc.cuf

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
index d324fe102ab7a..3c86a9d7451f0 100644
--- a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
@@ -1508,10 +1508,8 @@ static void genTMABulkLoad(fir::FirOpBuilder &builder, mlir::Location loc,
   auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
   barrier = builder.createConvert(loc, llvmPtrTy, barrier);
   setAlignment(dst, kTMAAlignment);
-  dst = convertPtrToNVVMSpace(builder, loc, dst,
-                              mlir::NVVM::NVVMMemorySpace::Shared);
-  src = convertPtrToNVVMSpace(builder, loc, src,
-                              mlir::NVVM::NVVMMemorySpace::Shared);
+  dst = builder.createConvert(loc, llvmPtrTy, dst);
+  src = builder.createConvert(loc, llvmPtrTy, src);
   mlir::NVVM::InlinePtxOp::create(
       builder, loc, mlir::TypeRange{}, {dst, src, size, barrier}, {},
       "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], "

diff  --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 20524b8f7de96..27ef8e0889627 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -543,7 +543,7 @@ end subroutine
 ! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
 ! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
 ! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
-! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr<3>, !llvm.ptr<3>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
 ! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
 
 attributes(global) subroutine test_tma_bulk_load_c8(a, n)
@@ -563,7 +563,7 @@ end subroutine
 ! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 16 : i32
 ! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
 ! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
-! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr<3>, !llvm.ptr<3>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
 ! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
 
 attributes(global) subroutine test_tma_bulk_load_i4(a, n)
@@ -583,7 +583,7 @@ end subroutine
 ! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
 ! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
 ! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
-! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr<3>, !llvm.ptr<3>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
 ! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
 
 attributes(global) subroutine test_tma_bulk_load_i8(a, n)
@@ -603,7 +603,7 @@ end subroutine
 ! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
 ! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
 ! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
-! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr<3>, !llvm.ptr<3>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
 ! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
 
 attributes(global) subroutine test_tma_bulk_load_r2(a, n)
@@ -623,7 +623,7 @@ end subroutine
 ! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 2 : i32
 ! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
 ! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
-! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr<3>, !llvm.ptr<3>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
 ! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
 
 attributes(global) subroutine test_tma_bulk_load_r4(a, n)
@@ -643,7 +643,7 @@ end subroutine
 ! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
 ! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
 ! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
-! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr<3>, !llvm.ptr<3>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
 ! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
 
 attributes(global) subroutine test_tma_bulk_load_r8(a, n)
@@ -663,7 +663,7 @@ end subroutine
 ! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
 ! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
 ! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
-! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr<3>, !llvm.ptr<3>, i32, !llvm.ptr)
+! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
 ! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
 
 attributes(global) subroutine test_tma_bulk_store_c4(c, n)


        


More information about the flang-commits mailing list