[Mlir-commits] [mlir] [MLIR][NVVM] Update TMA Load Op (PR #156347)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 1 08:43:44 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Durgadoss R (durga4github)

<details>
<summary>Changes</summary>

This patch adds im2col and gather mode
support to the TMA Load Op. The lowering is
also updated to use intrinsics, except when a predicate
is given (that case still lowers through inline PTX).
This completes the Blackwell additions to this Op.

* The NVVM dialect now supports the shared::cluster
   address space, so this patch also updates the
   Op to use AS(7) instead of AS(3). The corresponding
   inline-PTX based unit tests are updated accordingly.
* Lit tests are added for all combinations.


---

Patch is 116.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156347.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+37-29) 
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+4) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+88-8) 
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+39-65) 
- (added) mlir/test/Target/LLVMIR/nvvm/tma_load_im2col.mlir (+298) 
- (added) mlir/test/Target/LLVMIR/nvvm/tma_load_invalid.mlir (+37) 
- (added) mlir/test/Target/LLVMIR/nvvm/tma_load_tile.mlir (+204) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 9528da05c9fd6..aace19ab834c3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2367,6 +2367,23 @@ def TMAStoreModeAttr : EnumAttr<NVVM_Dialect, TMAStoreMode, "tma_store_mode"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
+// Num CTAs in a group participating in the TCGEN05 operation.
+// This corresponds to the "cta_group::1", "cta_group::2"
+// modifiers in the PTX instructions.
+def Tcgen05GroupCTA_1 : I32EnumAttrCase<"CTA_1", 0, "cta_1">;
+def Tcgen05GroupCTA_2 : I32EnumAttrCase<"CTA_2", 1, "cta_2">;
+
+def Tcgen05GroupKind : I32EnumAttr<"Tcgen05GroupKind",
+                            "NVVM Tcgen05 group kind",
+  [Tcgen05GroupCTA_1, Tcgen05GroupCTA_2]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def Tcgen05GroupKindAttr :
+  EnumAttr<NVVM_Dialect, Tcgen05GroupKind, "tcgen05_group"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
   Arguments<(ins )> {
   let assemblyFormat = "attr-dict";
@@ -2413,26 +2430,20 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
   NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", 
   [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
   AttrSizedOperandSegments, NVVMRequiresSM<90>]>,
-  Arguments<(ins  LLVM_PointerShared:$dstMem,
-                  LLVM_AnyPointer:$tmaDescriptor,
+  Arguments<(ins  LLVM_PointerSharedCluster:$dstMem,
+                  LLVM_PointerGeneric:$tmaDescriptor,
                   Variadic<I32>:$coordinates,
                   LLVM_PointerShared:$mbar,                  
                   Variadic<I16>:$im2colOffsets,
                   Optional<I16>:$multicastMask,
                   Optional<I64>:$l2CacheHint,
+                  DefaultValuedAttr<TMALoadModeAttr, "TMALoadMode::TILE">:$mode,
+                  OptionalAttr<Tcgen05GroupKindAttr>:$group,
                   PtxPredicate:$predicate)> {
   let description = [{
     Initiates an asynchronous copy operation on the tensor data from global 
-    memory to shared memory. 
-
-    The Op operates has two load modes:
-    1) Tiled Mode: It's the default mode. The source multi-dimensional tensor 
-    layout is preserved at the destination. 
-
-    2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
-    the elements in the Bounding Box of the source tensor are rearranged into
-    columns at the destination. In this mode, the tensor has to be at least 
-    3-dimensional. 
+    memory to shared::cluster memory. This Op supports all the load modes
+    specified in `TMALoadMode`.
 
     The `multicastMask` operand is optional. When it is present, the Op copies
     data from global memory to shared memory of multiple CTAs in the cluster.
@@ -2490,6 +2501,20 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
     }
   }];
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool hasIntrinsic() { return !getPredicate(); }
+
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase& builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+                      *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, id, args);
+  }];
 }
 
 def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : 
@@ -3314,23 +3339,6 @@ def NVVM_Breakpoint : NVVM_Op<"breakpoint"> {
 //===----------------------------------------------------------------------===//
 // NVVM TCGEN05 Ops
 //===----------------------------------------------------------------------===//
-// Num CTAs in a group participating in the TCGEN05 operation.
-// This corresponds to the "cta_group::1", "cta_group::2"
-// modifiers in the PTX instructions.
-def Tcgen05GroupCTA_1 : I32EnumAttrCase<"CTA_1", 0, "cta_1">;
-def Tcgen05GroupCTA_2 : I32EnumAttrCase<"CTA_2", 1, "cta_2">;
-
-def Tcgen05GroupKind : I32EnumAttr<"Tcgen05GroupKind",
-                            "NVVM Tcgen05 group kind",
-  [Tcgen05GroupCTA_1, Tcgen05GroupCTA_2]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "::mlir::NVVM";
-}
-def Tcgen05GroupKindAttr :
-  EnumAttr<NVVM_Dialect, Tcgen05GroupKind, "tcgen05_group"> {
-  let assemblyFormat = "`<` $value `>`";
-}
-
 def Tcgen05FenceBefore : I32EnumAttrCase<"BEFORE_THREAD_SYNC", 0, "before">;
 def Tcgen05FenceAfter  : I32EnumAttrCase<"AFTER_THREAD_SYNC",  1, "after">;
 def Tcgen05FenceKind : I32EnumAttr<"Tcgen05FenceKind", "NVVM Tcgen05 fence kind",
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index ab1666a0e8e75..be913eaaa27b8 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1003,9 +1003,13 @@ struct NVGPUTmaAsyncLoadOpLowering
     for (auto [index, value] : llvm::enumerate(coords)) {
       coords[index] = truncToI32(b, value);
     }
+
+    // TODO: Enhance the NVGPU Op for other modes too
     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
         op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
         ValueRange{}, adaptor.getMulticastMask(), Value{},
+        NVVM::TMALoadMode::TILE, // default is TILE mode
+        nullptr,                 // default is no cta-group
         adaptor.getPredicate());
     return success();
   }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index ff6ccbaac2b35..2bd550dc3b89b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -49,7 +49,7 @@ using namespace NVVM;
 //===----------------------------------------------------------------------===//
 
 // This verifier is shared among the following Ops:
-// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
+// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
 // CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
 static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                      bool isIm2Col,
@@ -73,13 +73,6 @@ static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
   return success();
 }
 
-LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
-  size_t numIm2ColOffsets = getIm2colOffsets().size();
-  bool isIm2Col = numIm2ColOffsets > 0;
-  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
-                                         numIm2ColOffsets, getLoc());
-}
-
 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
   TMAStoreMode mode = getMode();
   // We lower through inline-ptx when getPredicate() is true.
@@ -157,6 +150,17 @@ LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
                              getMode(), getLoc());
 }
 
+LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
+  TMALoadMode mode = getMode();
+  if (getPredicate()) {
+    if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
+      return emitError(
+          "Inline-ptx lowering supported only for Tile/Im2col mode.");
+  }
+  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
+                             getMode(), getLoc());
+}
+
 LogicalResult CpAsyncBulkTensorReduceOp::verify() {
   TMAStoreMode mode = getMode();
   size_t dims = getCoordinates().size();
@@ -1495,6 +1499,82 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
   return {id, std::move(args)};
 }
 
+mlir::NVVM::IDArgPair
+CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+
+  // Fill the Intrinsic Args
+  args.push_back(mt.lookupValue(thisOp.getDstMem()));
+  args.push_back(mt.lookupValue(thisOp.getMbar()));
+  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+  // Coordinates and im2col-offsets
+  for (auto v : thisOp.getCoordinates())
+    args.push_back(mt.lookupValue(v));
+  for (auto v : thisOp.getIm2colOffsets())
+    args.push_back(mt.lookupValue(v));
+
+  // MulticastMask, if available
+  mlir::Value mcMask = thisOp.getMulticastMask();
+  const bool hasMC = static_cast<bool>(mcMask);
+  llvm::Value *i16Unused =
+      llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
+  args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Unused);
+
+  // CacheHint, if available
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Unused =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+
+  // Flag arguments for multicast, cache-hint and CTAGroup
+  args.push_back(builder.getInt1(hasMC));
+  args.push_back(builder.getInt1(hasCacheHint));
+
+  // Flag argument CTAGroup
+  // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
+  // Hence, the +1 to getGroup().
+  const int32_t val =
+      thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
+  llvm::Value *cg =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
+  args.push_back(cg);
+
+  const unsigned NI = llvm::Intrinsic::not_intrinsic;
+  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+      {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
+      {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
+      {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
+      {NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
+      {NI, NI, NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}};
+
+  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
+                "TMALoadModes must match number of rows in IDTable");
+  size_t mode = static_cast<size_t>(thisOp.getMode());
+  size_t dim = thisOp.getCoordinates().size();
+  llvm::Intrinsic::ID id = IDTable[mode][dim];
+  if (id == llvm::Intrinsic::not_intrinsic)
+    llvm_unreachable(
+        "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
+
+  return {id, std::move(args)};
+}
+
 mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 89075120d16ea..f0bcf9f3498b0 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -96,119 +96,93 @@ func.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p
 }
 
 // CHECK-LABEL: @tma_load_3d_all
-func.func @tma_load_3d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "r,l,r,r,r,r,h,h,l"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "r,l,r,r,r,r,h,h,l,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
+func.func @tma_load_3d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<7>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "l,l,r,r,r,r,h,h,l,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr<7>, !llvm.ptr
   return
 }
 
 // CHECK-LABEL: @tma_load_4d_all
-func.func @tma_load_4d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "r,l,r,r,r,r,r,h,h,h,l"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$11 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "r,l,r,r,r,r,r,h,h,h,l,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
+func.func @tma_load_4d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<7>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$11 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "l,l,r,r,r,r,r,h,h,h,l,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr<7>, !llvm.ptr
   return
 }
 
 // CHECK-LABEL: @tma_load_5d_all
-func.func @tma_load_5d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %off0: i16, %off1: i16, %off2: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "r,l,r,r,r,r,r,r,h,h,h,h,l"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
-  // CHECK: lvm.inline_asm has_side_effects asm_dialect = att "@$13 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "r,l,r,r,r,r,r,r,h,h,h,h,l,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
+func.func @tma_load_5d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<7>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %off0: i16, %off1: i16, %off2: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
+  // CHECK: lvm.inline_asm has_side_effects asm_dialect = att "@$13 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "l,l,r,r,r,r,r,r,h,h,h,h,l,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr<7>, !llvm.ptr
   return
 }
 
 // CHECK-LABEL: @tma_load_1d
-func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "r,l,r,r"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0] : !llvm.ptr<3>, !llvm.ptr
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "r,l,r,r,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0] predicate=%p : !llvm.ptr<3>, !llvm.ptr
+func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<7>, %barrier: !llvm.ptr<3>, %crd0: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "l,l,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0] predicate=%p : !llvm.ptr<7>, !llvm.ptr
   return
 }
 
 // CHECK-LABEL: @tma_load_2d
-func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "r,l,r,r,r"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "r,l,r,r,r,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] predicate=%p  : !llvm.ptr<3>, !llvm.ptr
+func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<7>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "l,l,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] predicate=%p  : !llvm.ptr<7>, !llvm.ptr
   return
 }
 
 // CHECK-LABEL: @tma_load_3d
-func.func @tma_load_3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %bar...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/156347


More information about the Mlir-commits mailing list