[Mlir-commits] [mlir] 95aa70c - [MLIR][NVVM] Add support for shared::cta destination (#168056)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 17 00:28:27 PST 2025


Author: Durgadoss R
Date: 2025-11-17T13:58:21+05:30
New Revision: 95aa70cf209ab317d908e38630c92bf62d149247

URL: https://github.com/llvm/llvm-project/commit/95aa70cf209ab317d908e38630c92bf62d149247
DIFF: https://github.com/llvm/llvm-project/commit/95aa70cf209ab317d908e38630c92bf62d149247.diff

LOG: [MLIR][NVVM] Add support for shared::cta destination (#168056)

This patch adds support for shared::cta as a destination space in
the TMA non-tensor copy Op (from global to shared::cta).

* A verifier check is added that rejects a multicast mask when the
  destination is shared::cta.
* Unit tests are added to verify the lowering and the new diagnostic.

The related intrinsic changes were merged through PR #167508.

Signed-off-by: Durgadoss R <durgadossr at nvidia.com>

Added: 
    mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir

Modified: 
    flang/test/Lower/CUDA/cuda-device-proc.cuf
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir

Removed: 
    


################################################################################
diff  --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 69fefcf972065..434322ea22265 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -477,7 +477,7 @@ end subroutine
 ! CHECK: %[[DST_7:.*]] = llvm.addrspacecast %[[DST_PTR]] : !llvm.ptr to !llvm.ptr<7>
 ! CHECK: %[[SRC_PTR:.*]] = fir.convert %[[SRC]] : (!fir.ref<f64>) -> !llvm.ptr
 ! CHECK: %[[SRC_3:.*]] = llvm.addrspacecast %[[SRC_PTR]] : !llvm.ptr to !llvm.ptr<1>
-! CHECK: nvvm.cp.async.bulk.shared.cluster.global %[[DST_7]], %[[SRC_3]], %[[BARRIER_3]], %[[COUNT_LOAD]] : <7>, <1>
+! CHECK: nvvm.cp.async.bulk.shared.cluster.global %[[DST_7]], %[[SRC_3]], %[[BARRIER_3]], %[[COUNT_LOAD]] : !llvm.ptr<7>, <1>
 
 attributes(global) subroutine test_bulk_s2g(a)
   real(8), device :: a(*)

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 995ade5c9b033..d4ef5104d3c1f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3342,16 +3342,17 @@ def NVVM_CpAsyncBulkTensorReduceOp :
 
 def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
   NVVM_Op<"cp.async.bulk.shared.cluster.global", [AttrSizedOperandSegments]> {
-  let summary = "Async bulk copy from global memory to Shared cluster memory";
+  let summary = "Async bulk copy from global to Shared {cta or cluster} memory";
   let description = [{
-    Initiates an asynchronous copy operation from global memory to cluster's
-    shared memory.
+    Initiates an asynchronous copy operation from global memory to shared
+    memory or shared_cluster memory.
 
-    The `multicastMask` operand is optional. When it is present, the Op copies
+    The `multicastMask` operand is optional and can be used only when the
+    destination is shared::cluster memory. When it is present, this Op copies
     data from global memory to shared memory of multiple CTAs in the cluster.
     Operand `multicastMask` specifies the destination CTAs in the cluster such
     that each bit position in the 16-bit `multicastMask` operand corresponds to
-    the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
+    the `nvvm.read.ptx.sreg.ctaid` of the destination CTA. 
 
     The `l2CacheHint` operand is optional, and it is used to specify cache
     eviction policy that may be used during the memory access.
@@ -3360,7 +3361,7 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
   }];
 
   let arguments = (ins
-    LLVM_PointerSharedCluster:$dstMem,
+    AnyTypeOf<[LLVM_PointerShared, LLVM_PointerSharedCluster]>:$dstMem,
     LLVM_PointerGlobal:$srcMem,
     LLVM_PointerShared:$mbar,
     I32:$size,
@@ -3374,6 +3375,8 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
     attr-dict  `:` type($dstMem) `,` type($srcMem)
   }];
 
+  let hasVerifier = 1;
+
   let extraClassDeclaration = [{
     static mlir::NVVM::IDArgPair
     getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 0f7b3638fb30d..7ac427dbe3941 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -212,6 +212,14 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() {
   return success();
 }
 
+LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
+  bool isSharedCTA = isPtrInSharedCTASpace(getDstMem());
+  if (isSharedCTA && getMulticastMask())
+    return emitError("Multicast is not supported with shared::cta mode.");
+
+  return success();
+}
+
 LogicalResult ConvertFloatToTF32Op::verify() {
   using RndMode = NVVM::FPRoundingMode;
   switch (getRnd()) {
@@ -1980,11 +1988,15 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
   args.push_back(mt.lookupValue(thisOp.getSrcMem()));
   args.push_back(mt.lookupValue(thisOp.getSize()));
 
-  // Multicast mask, if available.
+  // Multicast mask for shared::cluster only, if available.
   mlir::Value multicastMask = thisOp.getMulticastMask();
   const bool hasMulticastMask = static_cast<bool>(multicastMask);
-  llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
-  args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
+  const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem());
+  if (!isSharedCTA) {
+    llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
+    args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
+                                    : i16Unused);
+  }
 
   // Cache hint, if available.
   mlir::Value cacheHint = thisOp.getL2CacheHint();
@@ -1993,11 +2005,14 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
   args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
 
   // Flag arguments for multicast and cachehint.
-  args.push_back(builder.getInt1(hasMulticastMask));
+  if (!isSharedCTA)
+    args.push_back(builder.getInt1(hasMulticastMask));
   args.push_back(builder.getInt1(hasCacheHint));
 
   llvm::Intrinsic::ID id =
-      llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
+      isSharedCTA
+          ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
+          : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
 
   return {id, std::move(args)};
 }

diff  --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
index 0daf24536a672..240fab5b63908 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
@@ -16,6 +16,17 @@ llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cluster(%dst : !llvm.ptr<7>,
   llvm.return
 }
 
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_global_to_shared_cta
+llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cta(%dst : !llvm.ptr<3>, %src : !llvm.ptr<1>, %mbar : !llvm.ptr<3>, %size : i32, %ch : i64) {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST:.*]], ptr addrspace(3) %[[MBAR:.*]], ptr addrspace(1) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true)
+  nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<1>
+
+  nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1>
+
+  llvm.return
+}
+
 // CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster
 llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr<7>, %src : !llvm.ptr<3>, %mbar : !llvm.ptr<3>, %size : i32) {
   // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.cluster(ptr addrspace(7) %0, ptr addrspace(3) %2, ptr addrspace(3) %1, i32 %3)

diff  --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir
new file mode 100644
index 0000000000000..d762ff3ff1e76
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+llvm.func @tma_bulk_copy_g2s_mc(%src : !llvm.ptr<1>, %dest : !llvm.ptr<3>, %bar : !llvm.ptr<3>, %size : i32, %ctamask : i16) {
+  // expected-error @below {{Multicast is not supported with shared::cta mode.}}
+  nvvm.cp.async.bulk.shared.cluster.global %dest, %src, %bar, %size multicast_mask = %ctamask : !llvm.ptr<3>, !llvm.ptr<1>
+
+  llvm.return
+}


        


More information about the Mlir-commits mailing list