[Mlir-commits] [mlir] [MLIR][NVVM] Update TMA tensor prefetch Op (PR #153464)

Durgadoss R llvmlistbot at llvm.org
Wed Aug 20 06:35:10 PDT 2025


https://github.com/durga4github updated https://github.com/llvm/llvm-project/pull/153464

>From 0e2866770b43382c4c0ec644fe5a260602d5fc11 Mon Sep 17 00:00:00 2001
From: Durgadoss R <durgadossr at nvidia.com>
Date: Tue, 12 Aug 2025 19:06:16 +0530
Subject: [PATCH] [MLIR][NVVM] Update TMA tensor prefetch Op

This patch updates the TMA Tensor prefetch Op
with support for im2col_w/w128 and tile_gather4 modes.
This completes support for all modes available in
Blackwell.
* lit tests are added for all possible combinations.
* The invalid tests are moved to a separate file
  with more coverage.

Signed-off-by: Durgadoss R <durgadossr at nvidia.com>
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   |  94 +++++++++-----
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 118 ++++++++++++++----
 .../test/Target/LLVMIR/nvvm/tma_prefetch.mlir | 117 ++++++++++++-----
 .../LLVMIR/nvvm/tma_prefetch_invalid.mlir     |  56 +++++++++
 mlir/test/Target/LLVMIR/nvvmir-invalid.mlir   |  26 ----
 5 files changed, 292 insertions(+), 119 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/nvvm/tma_prefetch_invalid.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f9cd58de8915f..40e264202a68c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2256,6 +2256,56 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 // NVVM TMA Ops
 //===----------------------------------------------------------------------===//
 
+// List of modes supported for TMA Load and Prefetch Ops
+def TMALoadModeTile   : I32EnumAttrCase<"TILE", 0, "tile">;
+def TMALoadModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
+def TMALoadModeIm2ColW : I32EnumAttrCase<"IM2COL_W", 2, "im2col_w">;
+def TMALoadModeIm2ColW128 : I32EnumAttrCase<"IM2COL_W_128", 3, "im2col_w_128">;
+def TMALoadModeTileGather4 : I32EnumAttrCase<"TILE_GATHER4", 4, "tile_gather4">;
+
+def TMALoadMode : I32EnumAttr<"TMALoadMode", "NVVM TMA Load Mode",
+    [TMALoadModeTile, TMALoadModeIm2Col,
+     TMALoadModeIm2ColW, TMALoadModeIm2ColW128,
+     TMALoadModeTileGather4]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def TMALoadModeAttr : EnumAttr<NVVM_Dialect, TMALoadMode, "tma_load_mode"> {
+  let summary = "List of Load-Modes supported for TMA Tensor Ops";
+  let description = [{
+    TMA Tensor Ops support the following modes, when copying data from
+    global memory to shared memory (i.e. load):
+
+    Tile Mode: It's the default mode. The source multi-dimensional tensor
+    layout is preserved at the destination.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-mode)
+
+    Im2col Mode: This mode is used when `im2colOffsets` operands are present.
+    The elements in the Bounding Box of the source tensor are rearranged into
+    columns at the destination. In this mode, the tensor has to be at least
+    3-dimensional. The number of `im2colOffsets` is `dims - 2` where `dims`
+    is the dimension of the tensor.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode)
+
+    Im2col_W Mode: This mode is similar to Im2Col mode with the restriction that
+    elements are accessed across the W dimension only. The number of `im2colOffsets`
+    is always two, referred to as `wHalo` and `wOffset`.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-w-w128-modes)
+
+    Im2col_W_128 Mode: This mode is similar to Im2Col_W mode, except that the number of
+    elements accessed across the W dimension is always 128.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-w-w128-modes)
+
+    Tile_Gather4 Mode: This mode is similar to Tile mode but works only on 2D tensors.
+    In gather4 mode, four rows in the source 2D tensor are combined to form a single
+    2D tensor at the destination. This mode requires five coordinates. The first one
+    represents the column index, followed by four row indices.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-scatter4-gather4-modes)
+  }];
+
+  let assemblyFormat = "`<` $value `>`";
+}
+
 def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
   Arguments<(ins )> {
   let assemblyFormat = "attr-dict";
@@ -2524,23 +2574,16 @@ def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> {
 def NVVM_CpAsyncBulkTensorPrefetchOp :
   NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
   let arguments = (ins
-    LLVM_AnyPointer:$tmaDescriptor,
+    LLVM_PointerGeneric:$tmaDescriptor,
     Variadic<I32>:$coordinates,
     Variadic<I16>:$im2colOffsets,
+    DefaultValuedAttr<TMALoadModeAttr, "TMALoadMode::TILE">:$mode,
     Optional<I64>:$l2CacheHint);
 
   let description = [{
     Initiates an asynchronous prefetch operation on the tensor data from global
-    memory to L2 cache.
-
-    The Op has two modes:
-    1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
-    layout is preserved at the destination.
-
-    2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
-    the elements in the Bounding Box of the source tensor are rearranged into
-    columns at the destination. In this mode, the tensor has to be at least
-    3-dimensional.
+    memory to L2 cache. This Op supports all the load modes specified in
+    `TMALoadMode`.
 
     The `l2CacheHint` operand is optional, and it is used to specify cache
     eviction policy that may be used during the memory access.
@@ -2557,34 +2600,17 @@ def NVVM_CpAsyncBulkTensorPrefetchOp :
   }];
 
   let extraClassDeclaration = [{
-    static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col);
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase& builder);
   }];
 
   let hasVerifier = 1;
 
   string llvmBuilder = [{
-    // Arguments to the intrinsic:
-    // tmaDesc, tensorDims, im2colOffsets
-    // cache_hint(if applicable) and flag(boolean)
-    llvm::SmallVector<llvm::Value *> translatedOperands;
-    translatedOperands.push_back($tmaDescriptor);
-
-    for (auto v : op.getCoordinates())
-      translatedOperands.push_back(moduleTranslation.lookupValue(v));
-
-    for (auto v : op.getIm2colOffsets())
-      translatedOperands.push_back(moduleTranslation.lookupValue(v));
-
-    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
-    auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
-
-    bool isCacheHint = op.getL2CacheHint() ? true : false;
-    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
-    translatedOperands.push_back(builder.getInt1(isCacheHint));
-
-    auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID(
-        op.getCoordinates().size(), op.getIm2colOffsets().size() > 0);
-    createIntrinsicCall(builder, intId, translatedOperands);
+    auto [id, args] = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+                      *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, id, args);
   }];
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index dbcc738b4419f..f8ac6fc413d77 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -50,7 +50,6 @@ using namespace NVVM;
 
 // This verifier is shared among the following Ops:
 // CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
-// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
 // CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
 static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                      bool isIm2Col,
@@ -98,11 +97,47 @@ LogicalResult CpAsyncOp::verify() {
   return success();
 }
 
+// This parameter verifier can be shared across TMA Load and Prefetch Ops.
+static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
+                                         TMALoadMode mode, Location loc) {
+  if (tensorDims < 1 || tensorDims > 5)
+    return emitError(loc, "expects coordinates between 1 to 5 dimension");
+
+  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
+                                size_t expectedIm2colOff) -> LogicalResult {
+    if (isIm2col && (tensorDims < 3))
+      return emitError(loc)
+             << "to use " << stringifyEnum(mode)
+             << " mode, the tensor has to be at least 3-dimensional";
+
+    if (numIm2colOff != expectedIm2colOff)
+      return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
+                            << " (provided " << numIm2colOff << ")";
+
+    return success();
+  };
+
+  switch (mode) {
+  case TMALoadMode::TILE:
+    return checkTMALoadParams(mode, false, 0);
+  case TMALoadMode::IM2COL:
+    return checkTMALoadParams(mode, true, tensorDims - 2);
+  case TMALoadMode::IM2COL_W:
+  case TMALoadMode::IM2COL_W_128:
+    return checkTMALoadParams(mode, true, 2);
+  case TMALoadMode::TILE_GATHER4:
+    return (tensorDims == 5)
+               ? checkTMALoadParams(mode, false, 0)
+               : emitError(loc, "Gather4 mode expects 5 coordinates");
+  default:
+    return emitError(loc, "Invalid LoadMode in CpAsyncBulkTensorPrefetchOp.");
+  }
+  return success();
+}
+
 LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
-  size_t numIm2ColOffsets = getIm2colOffsets().size();
-  bool isIm2Col = numIm2ColOffsets > 0;
-  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
-                                         numIm2ColOffsets, getLoc());
+  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
+                             getMode(), getLoc());
 }
 
 LogicalResult CpAsyncBulkTensorReduceOp::verify() {
@@ -1433,28 +1468,57 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
   return {id, std::move(args)};
 }
 
-llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
-                                                                bool isIm2Col) {
-  switch (tensorDims) {
-  case 1:
-    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
-  case 2:
-    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
-  case 3:
-    return isIm2Col
-               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
-               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
-  case 4:
-    return isIm2Col
-               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
-               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
-  case 5:
-    return isIm2Col
-               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
-               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
-  default:
-    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
-  }
+mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+
+  // Fill the Intrinsic Args
+  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+  for (auto v : thisOp.getCoordinates())
+    args.push_back(mt.lookupValue(v));
+  for (auto v : thisOp.getIm2colOffsets())
+    args.push_back(mt.lookupValue(v));
+
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Unused =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+  args.push_back(builder.getInt1(hasCacheHint));
+
+  const unsigned NI = llvm::Intrinsic::not_intrinsic;
+  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+      {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
+      {NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
+      {NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
+      {NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
+      {NI, NI, NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
+
+  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
+                "TMALoadModes must match number of rows in IDTable");
+  size_t mode = static_cast<size_t>(thisOp.getMode());
+  size_t dim = thisOp.getCoordinates().size();
+  llvm::Intrinsic::ID id = IDTable[mode][dim];
+  if (id == llvm::Intrinsic::not_intrinsic)
+    llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
+
+  return {id, std::move(args)};
 }
 
 #define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                        \
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir
index bfd952636ffbe..536b52b034db8 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir
@@ -1,70 +1,123 @@
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
 
-// CHECK-LABEL: @tma_bulk_prefetch
 llvm.func @tma_bulk_prefetch(%src : !llvm.ptr<1>, %size : i32, %ch : i64) {
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  // CHECK-LABEL: define void @tma_bulk_prefetch(ptr addrspace(1) %0, i32 %1, i64 %2) {
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %0, i32 %1, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %0, i32 %1, i64 %2, i1 true)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
   nvvm.cp.async.bulk.prefetch %src, %size : !llvm.ptr<1>
   nvvm.cp.async.bulk.prefetch %src, %size l2_cache_hint = %ch : !llvm.ptr<1>
   llvm.return
 }
 
-// CHECK-LABEL: @tma_prefetch_1d
 llvm.func @tma_prefetch_1d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  // CHECK-LABEL: define void @tma_prefetch_1d(ptr %0, i32 %1, i64 %2) {
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %1, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %1, i64 %2, i1 true)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] : !llvm.ptr
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] l2_cache_hint = %ch : !llvm.ptr
   llvm.return
 }
 
-// CHECK-LABEL: @tma_prefetch_2d
 llvm.func @tma_prefetch_2d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %ch : i64) {
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] : !llvm.ptr
+  // CHECK-LABEL: define void @tma_prefetch_2d(ptr %0, i32 %1, i32 %2, i64 %3) {
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %1, i32 %2, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %1, i32 %2, i64 %3, i1 true)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] {mode = #nvvm.tma_load_mode<tile>} : !llvm.ptr
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] l2_cache_hint = %ch : !llvm.ptr
   llvm.return
 }
 
-// CHECK-LABEL: @tma_prefetch_3d
-llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %ch : i64) {
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
+  // CHECK-LABEL: define void @tma_prefetch_3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 %6) {
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %1, i32 %2, i32 %3, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %1, i32 %2, i32 %3, i64 %6, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i64 %6, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 %6, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 %6, i1 true)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] : !llvm.ptr
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] l2_cache_hint = %ch : !llvm.ptr
 
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] : !llvm.ptr
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
   llvm.return
 }
 
-// CHECK-LABEL: @tma_prefetch_4d
 llvm.func @tma_prefetch_4d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  // CHECK-LABEL: define void @tma_prefetch_4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7) {
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i64 %7, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7, i1 true)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] : !llvm.ptr
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch : !llvm.ptr
 
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] : !llvm.ptr
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
   llvm.return
 }
 
-// CHECK-LABEL: @tma_prefetch_5d
 llvm.func @tma_prefetch_5d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  // CHECK-LABEL: define void @tma_prefetch_5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i16 %8, i64 %9) {
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 %9, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i16 %8, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i16 %8, i64 %9, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 %9, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 %9, i1 true)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] : !llvm.ptr
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] l2_cache_hint = %ch : !llvm.ptr
 
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] : !llvm.ptr
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] l2_cache_hint = %ch : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+  llvm.return
+}
+
+llvm.func @tma_prefetch_gather4_2d(%tma_desc : !llvm.ptr, %x0 : i32, %y1 : i32, %y2 : i32, %y3 : i32, %y4 : i32, %ch : i64) {
+  // CHECK-LABEL: define void @tma_prefetch_gather4_2d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 %6) {
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.gather4.2d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.gather4.2d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 %6, i1 true)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%x0, %y1, %y2, %y3, %y4] {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%x0, %y1, %y2, %y3, %y4] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
   llvm.return
 }
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_prefetch_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_prefetch_invalid.mlir
new file mode 100644
index 0000000000000..23e47bd594b78
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_prefetch_invalid.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+llvm.func @tma_prefetch_0d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
+  // expected-error @below {{expects coordinates between 1 to 5 dimension}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[] : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_2d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %off0 : i16, %ch : i64) {
+  // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] im2col[%off0] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_5d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
+  // expected-error @below {{im2col offsets expected 3 (provided 2)}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_3d_im2col_w(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16) {
+  // expected-error @below {{im2col offsets expected 2 (provided 1)}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_4d_im2col_w_128(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16) {
+  // expected-error @below {{im2col offsets expected 2 (provided 1)}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_gather4_3d(%tma_desc : !llvm.ptr, %d0 : i32) {
+  // expected-error @below {{Gather4 mode expects 5 coordinates}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_gather4_2d(%tma_desc : !llvm.ptr, %x0 : i32, %y1 : i32, %y2 : i32, %y3 : i32, %y4 : i32, %off0 : i16, %ch : i64) {
+  // expected-error @below {{im2col offsets expected 0 (provided 1)}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%x0, %y1, %y2, %y3, %y4] im2col[%off0] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
+  llvm.return
+}
+
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 33398cfb92429..53c7d5f741575 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -86,32 +86,6 @@ llvm.func @nvvm_fence_proxy_release() {
   llvm.return
 }
 
-// -----
-
-llvm.func @tma_prefetch_0d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
-  // expected-error @below {{expects coordinates between 1 to 5 dimension}}
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[] : !llvm.ptr
-  llvm.return
-}
-
-// -----
-
-llvm.func @tma_prefetch_2d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %off0 : i16, %ch : i64) {
-  // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
-  llvm.return
-}
-
-// -----
-
-llvm.func @tma_prefetch_5d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
-  // expected-error @below {{im2col offsets must be 2 less than number of coordinates}}
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] : !llvm.ptr
-  llvm.return
-}
-
-// -----
-
 llvm.func @tma_reduce_0d(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %ch : i64) {
   // expected-error @below {{expects coordinates between 1 to 5 dimension}}
   nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[] {redKind = #nvvm.tma_redux_kind<add>}: !llvm.ptr, !llvm.ptr<3>



More information about the Mlir-commits mailing list