[Mlir-commits] [mlir] [MLIR][NVVM] Extend TMA Bulk Copy Op (PR #140232)

Durgadoss R llvmlistbot at llvm.org
Fri May 16 05:15:58 PDT 2025


https://github.com/durga4github updated https://github.com/llvm/llvm-project/pull/140232

From 3d0efbda07cb74d96ef3ed1ebc942b6eca7b5797 Mon Sep 17 00:00:00 2001
From: Durgadoss R <durgadossr at nvidia.com>
Date: Fri, 9 May 2025 22:27:19 +0530
Subject: [PATCH] [MLIR][NVVM] Extend TMA Bulk Copy Op

This patch extends the non-tensor TMA Bulk Copy Op
(from shared_cta to global) with an optional 16-bit
byte-mask operand. Each bit in the mask selects
whether the corresponding byte of each 16-byte chunk
of source data is copied to the destination.
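
For illustration, here is a minimal sketch of the masked form (the
constant value and the SSA names are hypothetical; the exact syntax
is shown in the examples the patch adds below):

```mlir
// Copy only the lower 8 bytes of each 16-byte chunk (bits 0-7 set).
%mask = llvm.mlir.constant(255 : i16) : i16
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size byte_mask = %mask
    : !llvm.ptr<1>, !llvm.ptr<3>
```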

* lit tests are added to verify the lowering to
  the corresponding intrinsics.

Signed-off-by: Durgadoss R <durgadossr at nvidia.com>
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 69 +++++++++++--------
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 28 ++++++++
 .../Target/LLVMIR/nvvm/tma_bulk_copy.mlir     | 12 +++-
 3 files changed, 81 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a8e7dcb54ac20..e07819104c781 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2599,15 +2599,37 @@ def NVVM_CpAsyncBulkSharedCTAToSharedClusterOp :
 }
 
 def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
-  NVVM_Op<"cp.async.bulk.global.shared.cta"> {
+  NVVM_Op<"cp.async.bulk.global.shared.cta", [AttrSizedOperandSegments]> {
   let summary = "Async bulk copy from Shared CTA memory to Global memory";
   let description = [{
     Initiates an asynchronous copy operation from Shared CTA memory to
-    global memory.
+    global memory. The 32-bit operand `size` specifies the amount of
+    memory to be copied, in bytes, and must be a multiple of 16. The
+    optional `l2CacheHint` operand specifies the cache eviction policy
+    that may be applied during the memory access. The optional 16-bit
+    `byteMask` operand selects which bytes are written: when the i-th
+    bit is set, the i-th byte of each 16-byte chunk of source data is
+    copied to the destination. For example, a mask of 0x00FF copies
+    only the lower eight bytes of every 16-byte chunk.
+
+    Example:
+    ```mlir
+      nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size
+          : !llvm.ptr<1>, !llvm.ptr<3>
+
+      // with l2_cache_hint
+      nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch
+          : !llvm.ptr<1>, !llvm.ptr<3>
+
+      // with byte_mask
+      nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size byte_mask = %mask
+          : !llvm.ptr<1>, !llvm.ptr<3>
+
+      // with both l2_cache_hint and byte_mask
+      nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch byte_mask = %mask
+          : !llvm.ptr<1>, !llvm.ptr<3>
+    ```
 
-    The `l2CacheHint` operand is optional, and it is used to specify cache
-    eviction policy that may be used during the memory access.
-    
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
   }];
 
@@ -2615,35 +2637,28 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
     LLVM_PointerGlobal:$dstMem,
     LLVM_PointerShared:$srcMem,
     I32:$size,
-    Optional<I64>:$l2CacheHint);
+    Optional<I64>:$l2CacheHint,
+    Optional<I16>:$byteMask);
 
   let assemblyFormat = [{
     $dstMem `,` $srcMem `,` $size
     (`l2_cache_hint` `=` $l2CacheHint^ )?
-    attr-dict  `:` type($dstMem) `,` type($srcMem)
+    (`byte_mask` `=` $byteMask^ )?
+    attr-dict `:` type($dstMem) `,` type($srcMem)
   }];
 
+  let extraClassDeclaration = [{
+    static void getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                                      llvm::IRBuilderBase &builder,
+                                      llvm::Intrinsic::ID &id,
+                                      llvm::SmallVector<llvm::Value *> &args);
+  }];
   string llvmBuilder = [{
-    // Arguments to the intrinsic:
-    // dst, src, size, cache_hint,
-    // Flag for cache_hint
-    //
-    llvm::SmallVector<llvm::Value *> translatedOperands;
-    translatedOperands.push_back($dstMem);
-    translatedOperands.push_back($srcMem);
-    translatedOperands.push_back($size);
-
-    // Cachehint, if available
-    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
-    auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
-    bool isCacheHint = op.getL2CacheHint() ? true : false;
-    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
-
-    // Flag argument for cachehint
-    translatedOperands.push_back(builder.getInt1(isCacheHint));
-
-    createIntrinsicCall(builder,
-      llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global, translatedOperands);
+    llvm::SmallVector<llvm::Value *> args;
+    llvm::Intrinsic::ID id;
+    NVVM::CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
+        *op, moduleTranslation, builder, id, args);
+    createIntrinsicCall(builder, id, args);
   }];
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 1ea3f96fa75f5..df2b0e025ae11 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1253,6 +1253,34 @@ CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
   return id;
 }
 
+void CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt,
+    llvm::IRBuilderBase &builder,
+    llvm::Intrinsic::ID &id,
+    llvm::SmallVector<llvm::Value *> &args) {
+  auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
+
+  // Fill the common intrinsic arguments: dst, src, and size.
+  args.push_back(mt.lookupValue(thisOp.getDstMem()));
+  args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+  args.push_back(mt.lookupValue(thisOp.getSize()));
+
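+  // When the cache hint is absent, pass i64 0 and set the i1 flag to false.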
+  auto cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  auto *i64Unused =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+  args.push_back(builder.getInt1(hasCacheHint));
+
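+  // The byte-mask form lowers to a separate intrinsic that takes the
+  // mask as its final argument.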
+  if (auto byteMask = thisOp.getByteMask()) {
+    args.push_back(mt.lookupValue(byteMask));
+    id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
+    return;
+  }
+
+  id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
+}
+
 llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                 bool isIm2Col) {
   switch (tensorDims) {
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
index 39b703d9a9677..0daf24536a672 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
@@ -26,9 +26,19 @@ llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr
 // CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global
 llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global(%dst : !llvm.ptr<1>, %src : !llvm.ptr<3>, %size : i32, %ch : i64) {
   // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 %[[CH:.*]], i1 true)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST]], ptr addrspace(3) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true)
   nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size : !llvm.ptr<1>, !llvm.ptr<3>
 
   nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch : !llvm.ptr<1>, !llvm.ptr<3>
   llvm.return
 }
+
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global_bytemask
+llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global_bytemask(%dst : !llvm.ptr<1>, %src : !llvm.ptr<3>, %size : i32, %ch : i64, %mask : i16) {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global.bytemask(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false, i16 %[[MASK:.*]])
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global.bytemask(ptr addrspace(1) %[[DST]], ptr addrspace(3) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true, i16 %[[MASK]])
+  nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size byte_mask = %mask : !llvm.ptr<1>, !llvm.ptr<3>
+
+  nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch byte_mask = %mask : !llvm.ptr<1>, !llvm.ptr<3>
+  llvm.return
+}


