[Mlir-commits] [mlir] [MLIR][NVVM] Extend TMA Bulk Copy Op (PR #140232)
Durgadoss R
llvmlistbot at llvm.org
Fri May 16 05:15:58 PDT 2025
https://github.com/durga4github updated https://github.com/llvm/llvm-project/pull/140232
>From 3d0efbda07cb74d96ef3ed1ebc942b6eca7b5797 Mon Sep 17 00:00:00 2001
From: Durgadoss R <durgadossr at nvidia.com>
Date: Fri, 9 May 2025 22:27:19 +0530
Subject: [PATCH] [MLIR][NVVM] Extend TMA Bulk Copy Op
This patch extends the non-tensor TMA Bulk Copy Op
(from shared_cta to global) with an optional 16-bit
byte-mask operand. Each bit in the mask selects
whether the corresponding byte of each 16-byte chunk
of source data is copied to the destination.
* lit tests are added to verify the lowering to
the intrinsics.
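For illustration, a mask of 0x00FF copies only the lower
eight bytes of each 16-byte chunk. A minimal sketch of such
a call (the constant value and SSA names here are hypothetical,
not part of this patch):

```mlir
// Hypothetical usage: 0x00FF sets bits 0-7, so only bytes 0-7
// of each 16-byte chunk are written to the destination.
%mask = llvm.mlir.constant(255 : i16) : i16
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size byte_mask = %mask
  : !llvm.ptr<1>, !llvm.ptr<3>
```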
Signed-off-by: Durgadoss R <durgadossr at nvidia.com>
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 69 +++++++++++--------
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 28 ++++++++
.../Target/LLVMIR/nvvm/tma_bulk_copy.mlir | 12 +++-
3 files changed, 81 insertions(+), 28 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a8e7dcb54ac20..e07819104c781 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2599,15 +2599,37 @@ def NVVM_CpAsyncBulkSharedCTAToSharedClusterOp :
}
def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
- NVVM_Op<"cp.async.bulk.global.shared.cta"> {
+ NVVM_Op<"cp.async.bulk.global.shared.cta", [AttrSizedOperandSegments]> {
let summary = "Async bulk copy from Shared CTA memory to Global memory";
let description = [{
Initiates an asynchronous copy operation from Shared CTA memory to
- global memory.
+ global memory. The 32-bit operand `size` specifies the number of
+ bytes to be copied, and it must be a multiple of 16. The optional
+ `l2CacheHint` operand specifies the cache eviction policy that may
+ be used during the memory access. The optional 16-bit `byteMask`
+ operand enables selective copying: the i-th bit of the mask
+ specifies whether the i-th byte of each 16-byte chunk of source
+ data is copied to the destination. If the bit is set, the byte is
+ copied.
+
+ Example:
+ ```mlir
+ nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size
+ : !llvm.ptr<1>, !llvm.ptr<3>
+
+ // with l2_cache_hint
+ nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch
+ : !llvm.ptr<1>, !llvm.ptr<3>
+
+ // with byte_mask
+ nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size byte_mask = %mask
+ : !llvm.ptr<1>, !llvm.ptr<3>
+
+ // with both l2_cache_hint and byte_mask
+ nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch byte_mask = %mask
+ : !llvm.ptr<1>, !llvm.ptr<3>
+ ```
- The `l2CacheHint` operand is optional, and it is used to specify cache
- eviction policy that may be used during the memory access.
-
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
}];
@@ -2615,35 +2637,28 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
LLVM_PointerGlobal:$dstMem,
LLVM_PointerShared:$srcMem,
I32:$size,
- Optional<I64>:$l2CacheHint);
+ Optional<I64>:$l2CacheHint,
+ Optional<I16>:$byteMask);
let assemblyFormat = [{
$dstMem `,` $srcMem `,` $size
(`l2_cache_hint` `=` $l2CacheHint^ )?
- attr-dict `:` type($dstMem) `,` type($srcMem)
+ (`byte_mask` `=` $byteMask^ )?
+ attr-dict `:` type($dstMem) `,` type($srcMem)
}];
+ let extraClassDeclaration = [{
+ static void getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder,
+ llvm::Intrinsic::ID &id,
+ llvm::SmallVector<llvm::Value *> &args);
+ }];
string llvmBuilder = [{
- // Arguments to the intrinsic:
- // dst, src, size, cache_hint,
- // Flag for cache_hint
- //
- llvm::SmallVector<llvm::Value *> translatedOperands;
- translatedOperands.push_back($dstMem);
- translatedOperands.push_back($srcMem);
- translatedOperands.push_back($size);
-
- // Cachehint, if available
- llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
- auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
- bool isCacheHint = op.getL2CacheHint() ? true : false;
- translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
-
- // Flag argument for cachehint
- translatedOperands.push_back(builder.getInt1(isCacheHint));
-
- createIntrinsicCall(builder,
- llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global, translatedOperands);
+ llvm::SmallVector<llvm::Value *> args;
+ llvm::Intrinsic::ID id;
+ NVVM::CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder, id, args);
+ createIntrinsicCall(builder, id, args);
}];
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 1ea3f96fa75f5..df2b0e025ae11 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1253,6 +1253,34 @@ CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
return id;
}
+void CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder,
+ llvm::Intrinsic::ID &id,
+ llvm::SmallVector<llvm::Value *> &args) {
+ auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
+
+ // Fill the intrinsic args: dst, src and size.
+ args.push_back(mt.lookupValue(thisOp.getDstMem()));
+ args.push_back(mt.lookupValue(thisOp.getSrcMem()));
+ args.push_back(mt.lookupValue(thisOp.getSize()));
+
+ auto cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ auto *i64Unused =
+ llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+ args.push_back(builder.getInt1(hasCacheHint));
+
+ if (auto byteMask = thisOp.getByteMask()) {
+ args.push_back(mt.lookupValue(byteMask));
+ id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask;
+ return;
+ }
+
+ id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global;
+}
+
llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
bool isIm2Col) {
switch (tensorDims) {
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
index 39b703d9a9677..0daf24536a672 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
@@ -26,9 +26,19 @@ llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr
// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global
llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global(%dst : !llvm.ptr<1>, %src : !llvm.ptr<3>, %size : i32, %ch : i64) {
// CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 %[[CH:.*]], i1 true)
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global(ptr addrspace(1) %[[DST]], ptr addrspace(3) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true)
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size : !llvm.ptr<1>, !llvm.ptr<3>
nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch : !llvm.ptr<1>, !llvm.ptr<3>
llvm.return
}
+
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_global_bytemask
+llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_global_bytemask(%dst : !llvm.ptr<1>, %src : !llvm.ptr<3>, %size : i32, %ch : i64, %mask : i16) {
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global.bytemask(ptr addrspace(1) %[[DST:.*]], ptr addrspace(3) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false, i16 %[[MASK:.*]])
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.global.bytemask(ptr addrspace(1) %[[DST]], ptr addrspace(3) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true, i16 %[[MASK]])
+ nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size byte_mask = %mask : !llvm.ptr<1>, !llvm.ptr<3>
+
+ nvvm.cp.async.bulk.global.shared.cta %dst, %src, %size l2_cache_hint = %ch byte_mask = %mask : !llvm.ptr<1>, !llvm.ptr<3>
+ llvm.return
+}