[Mlir-commits] [mlir] [MLIR][NVVM] Update TMA Load Op (PR #156347)
llvmlistbot at llvm.org
Mon Sep 1 08:43:44 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Durgadoss R (durga4github)
Changes:
This patch adds im2col and gather mode support
to the TMA Load Op. The lowering is also updated
to go through intrinsics, except when a predicate
is given. This completes the Blackwell additions
to this Op. An example of the updated syntax is
shown below.
* The NVVM dialect now supports the Shared::Cluster
address space, so this patch also updates the Op to
use AS(7) instead of AS(3). The corresponding
inline-PTX-based unit tests are updated accordingly.
* lit tests are added for all mode and dimension combinations.
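For reference, the updated op in im2col mode now looks like this (a minimal sketch adapted from the tests in this patch; note the `!llvm.ptr<7>` shared::cluster destination and the new `mode` attribute):

```mlir
// im2col-mode TMA load into shared::cluster memory (address space 7).
nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier,
    box[%crd0, %crd1, %crd2] im2col[%off0]
    multicast_mask = %ctamask l2_cache_hint = %cacheHint
    {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr<7>, !llvm.ptr
```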
---
Patch is 116.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156347.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+37-29)
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+4)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+88-8)
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+39-65)
- (added) mlir/test/Target/LLVMIR/nvvm/tma_load_im2col.mlir (+298)
- (added) mlir/test/Target/LLVMIR/nvvm/tma_load_invalid.mlir (+37)
- (added) mlir/test/Target/LLVMIR/nvvm/tma_load_tile.mlir (+204)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 9528da05c9fd6..aace19ab834c3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2367,6 +2367,23 @@ def TMAStoreModeAttr : EnumAttr<NVVM_Dialect, TMAStoreMode, "tma_store_mode"> {
let assemblyFormat = "`<` $value `>`";
}
+// Num CTAs in a group participating in the TCGEN05 operation.
+// This corresponds to the "cta_group::1", "cta_group::2"
+// modifiers in the PTX instructions.
+def Tcgen05GroupCTA_1 : I32EnumAttrCase<"CTA_1", 0, "cta_1">;
+def Tcgen05GroupCTA_2 : I32EnumAttrCase<"CTA_2", 1, "cta_2">;
+
+def Tcgen05GroupKind : I32EnumAttr<"Tcgen05GroupKind",
+ "NVVM Tcgen05 group kind",
+ [Tcgen05GroupCTA_1, Tcgen05GroupCTA_2]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def Tcgen05GroupKindAttr :
+ EnumAttr<NVVM_Dialect, Tcgen05GroupKind, "tcgen05_group"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
@@ -2413,26 +2430,20 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
AttrSizedOperandSegments, NVVMRequiresSM<90>]>,
- Arguments<(ins LLVM_PointerShared:$dstMem,
- LLVM_AnyPointer:$tmaDescriptor,
+ Arguments<(ins LLVM_PointerSharedCluster:$dstMem,
+ LLVM_PointerGeneric:$tmaDescriptor,
Variadic<I32>:$coordinates,
LLVM_PointerShared:$mbar,
Variadic<I16>:$im2colOffsets,
Optional<I16>:$multicastMask,
Optional<I64>:$l2CacheHint,
+ DefaultValuedAttr<TMALoadModeAttr, "TMALoadMode::TILE">:$mode,
+ OptionalAttr<Tcgen05GroupKindAttr>:$group,
PtxPredicate:$predicate)> {
let description = [{
Initiates an asynchronous copy operation on the tensor data from global
- memory to shared memory.
-
- The Op operates has two load modes:
- 1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
- layout is preserved at the destination.
-
- 2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
- the elements in the Bounding Box of the source tensor are rearranged into
- columns at the destination. In this mode, the tensor has to be at least
- 3-dimensional.
+ memory to shared::cluster memory. This Op supports all the load modes
+ specified in `TMALoadMode`.
The `multicastMask` operand is optional. When it is present, the Op copies
data from global memory to shared memory of multiple CTAs in the cluster.
@@ -2490,6 +2501,20 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
}
}];
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ bool hasIntrinsic() { return !getPredicate(); }
+
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [id, args] = NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder);
+ createIntrinsicCall(builder, id, args);
+ }];
}
def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
@@ -3314,23 +3339,6 @@ def NVVM_Breakpoint : NVVM_Op<"breakpoint"> {
//===----------------------------------------------------------------------===//
// NVVM TCGEN05 Ops
//===----------------------------------------------------------------------===//
-// Num CTAs in a group participating in the TCGEN05 operation.
-// This corresponds to the "cta_group::1", "cta_group::2"
-// modifiers in the PTX instructions.
-def Tcgen05GroupCTA_1 : I32EnumAttrCase<"CTA_1", 0, "cta_1">;
-def Tcgen05GroupCTA_2 : I32EnumAttrCase<"CTA_2", 1, "cta_2">;
-
-def Tcgen05GroupKind : I32EnumAttr<"Tcgen05GroupKind",
- "NVVM Tcgen05 group kind",
- [Tcgen05GroupCTA_1, Tcgen05GroupCTA_2]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::NVVM";
-}
-def Tcgen05GroupKindAttr :
- EnumAttr<NVVM_Dialect, Tcgen05GroupKind, "tcgen05_group"> {
- let assemblyFormat = "`<` $value `>`";
-}
-
def Tcgen05FenceBefore : I32EnumAttrCase<"BEFORE_THREAD_SYNC", 0, "before">;
def Tcgen05FenceAfter : I32EnumAttrCase<"AFTER_THREAD_SYNC", 1, "after">;
def Tcgen05FenceKind : I32EnumAttr<"Tcgen05FenceKind", "NVVM Tcgen05 fence kind",
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index ab1666a0e8e75..be913eaaa27b8 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1003,9 +1003,13 @@ struct NVGPUTmaAsyncLoadOpLowering
for (auto [index, value] : llvm::enumerate(coords)) {
coords[index] = truncToI32(b, value);
}
+
+ // TODO: Enhance the NVGPU Op for other modes too
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
ValueRange{}, adaptor.getMulticastMask(), Value{},
+ NVVM::TMALoadMode::TILE, // default is TILE mode
+ nullptr, // default is no cta-group
adaptor.getPredicate());
return success();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index ff6ccbaac2b35..2bd550dc3b89b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -49,7 +49,7 @@ using namespace NVVM;
//===----------------------------------------------------------------------===//
// This verifier is shared among the following Ops:
-// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
+// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
bool isIm2Col,
@@ -73,13 +73,6 @@ static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
return success();
}
-LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
- size_t numIm2ColOffsets = getIm2colOffsets().size();
- bool isIm2Col = numIm2ColOffsets > 0;
- return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
- numIm2ColOffsets, getLoc());
-}
-
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
TMAStoreMode mode = getMode();
// We lower through inline-ptx when getPredicate() is true.
@@ -157,6 +150,17 @@ LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
getMode(), getLoc());
}
+LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
+ TMALoadMode mode = getMode();
+ if (getPredicate()) {
+ if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
+ return emitError(
+ "Inline-ptx lowering supported only for Tile/Im2col mode.");
+ }
+ return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
+ getMode(), getLoc());
+}
+
LogicalResult CpAsyncBulkTensorReduceOp::verify() {
TMAStoreMode mode = getMode();
size_t dims = getCoordinates().size();
@@ -1495,6 +1499,82 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}
+mlir::NVVM::IDArgPair
+CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args
+ args.push_back(mt.lookupValue(thisOp.getDstMem()));
+ args.push_back(mt.lookupValue(thisOp.getMbar()));
+ args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+ // Coordinates and im2col-offsets
+ for (auto v : thisOp.getCoordinates())
+ args.push_back(mt.lookupValue(v));
+ for (auto v : thisOp.getIm2colOffsets())
+ args.push_back(mt.lookupValue(v));
+
+ // MulticastMask, if available
+ mlir::Value mcMask = thisOp.getMulticastMask();
+ const bool hasMC = static_cast<bool>(mcMask);
+ llvm::Value *i16Unused =
+ llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
+ args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Unused);
+
+ // CacheHint, if available
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Unused =
+ llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+
+ // Flag arguments for multicast and cache-hint
+ args.push_back(builder.getInt1(hasMC));
+ args.push_back(builder.getInt1(hasCacheHint));
+
+ // Flag argument for CTAGroup:
+ // CTA_1/CTA_2 map to the values 1 and 2 in the intrinsics;
+ // hence the +1 on getGroup().
+ const int32_t val =
+ thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
+ llvm::Value *cg =
+ llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
+ args.push_back(cg);
+
+ const unsigned NI = llvm::Intrinsic::not_intrinsic;
+ static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+ {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
+ {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
+ {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
+ {NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
+ {NI, NI, NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}};
+
+ static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
+ "TMALoadModes must match number of rows in IDTable");
+ size_t mode = static_cast<size_t>(thisOp.getMode());
+ size_t dim = thisOp.getCoordinates().size();
+ llvm::Intrinsic::ID id = IDTable[mode][dim];
+ if (id == llvm::Intrinsic::not_intrinsic)
+ llvm_unreachable(
+ "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
+
+ return {id, std::move(args)};
+}
+
mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 89075120d16ea..f0bcf9f3498b0 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -96,119 +96,93 @@ func.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p
}
// CHECK-LABEL: @tma_load_3d_all
-func.func @tma_load_3d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "r,l,r,r,r,r,h,h,l"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "r,l,r,r,r,r,h,h,l,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
+func.func @tma_load_3d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<7>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "l,l,r,r,r,r,h,h,l,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr<7>, !llvm.ptr
return
}
// CHECK-LABEL: @tma_load_4d_all
-func.func @tma_load_4d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "r,l,r,r,r,r,r,h,h,h,l"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$11 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "r,l,r,r,r,r,r,h,h,h,l,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
+func.func @tma_load_4d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<7>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$11 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "l,l,r,r,r,r,r,h,h,h,l,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr<7>, !llvm.ptr
return
}
// CHECK-LABEL: @tma_load_5d_all
-func.func @tma_load_5d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %off0: i16, %off1: i16, %off2: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "r,l,r,r,r,r,r,r,h,h,h,h,l"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr
- // CHECK: lvm.inline_asm has_side_effects asm_dialect = att "@$13 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "r,l,r,r,r,r,r,r,h,h,h,h,l,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
+func.func @tma_load_5d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<7>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %off0: i16, %off1: i16, %off2: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$13 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "l,l,r,r,r,r,r,r,h,h,h,h,l,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr<7>, !llvm.ptr
return
}
// CHECK-LABEL: @tma_load_1d
-func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "r,l,r,r"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0] : !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "r,l,r,r,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0] predicate=%p : !llvm.ptr<3>, !llvm.ptr
+func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<7>, %barrier: !llvm.ptr<3>, %crd0: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "l,l,r,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0] predicate=%p : !llvm.ptr<7>, !llvm.ptr
return
}
// CHECK-LABEL: @tma_load_2d
-func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "r,l,r,r,r"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "r,l,r,r,r,b"
- nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] predicate=%p : !llvm.ptr<3>, !llvm.ptr
+func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<7>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "l,l,r,r,r,b"
+ nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] predicate=%p : !llvm.ptr<7>, !llvm.ptr
return
}
// CHECK-LABEL: @tma_load_3d
-func.func @tma_load_3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %bar...
[truncated]
``````````
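To illustrate the intrinsic selection above: `getIntrinsicIDAndArgs` indexes `IDTable` by (mode, tensor rank) and, when the optional operands and attributes are absent, appends zero placeholders plus the two `i1` flags and the `i32` cta-group value. Below is a hedged sketch of what a 2-D tile-mode load (the default) should translate to; the intrinsic name follows mechanically from `nvvm_cp_async_bulk_tensor_g2s_tile_2d`, but the exact CHECK lines in the tests added by this patch may differ:

```mlir
// Sketch: default TILE mode, no multicast, cache-hint, or cta-group.
llvm.func @tma_load_2d_tile(%tmap: !llvm.ptr, %dst: !llvm.ptr<7>,
                            %bar: !llvm.ptr<3>, %c0: i32, %c1: i32) {
  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.2d(
  // CHECK-SAME: ptr addrspace(7) %{{.*}}, ptr addrspace(3) %{{.*}}, ptr %{{.*}},
  // CHECK-SAME: i32 %{{.*}}, i32 %{{.*}}, i16 0, i64 0, i1 false, i1 false, i32 0)
  nvvm.cp.async.bulk.tensor.shared.cluster.global %dst, %tmap, %bar, box[%c0, %c1]
      : !llvm.ptr<7>, !llvm.ptr
  llvm.return
}
```

With `group = #nvvm.tcgen05_group<cta_1>` (respectively `cta_2`), the trailing `i32` becomes 1 (respectively 2), per the `+1` mapping in the lowering.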
https://github.com/llvm/llvm-project/pull/156347