[Mlir-commits] [mlir] [MLIR][NVVM] Update TMA tensor prefetch Op (PR #153464)
Durgadoss R
llvmlistbot at llvm.org
Wed Aug 20 06:35:10 PDT 2025
https://github.com/durga4github updated https://github.com/llvm/llvm-project/pull/153464
>From 0e2866770b43382c4c0ec644fe5a260602d5fc11 Mon Sep 17 00:00:00 2001
From: Durgadoss R <durgadossr at nvidia.com>
Date: Tue, 12 Aug 2025 19:06:16 +0530
Subject: [PATCH] [MLIR][NVVM] Update TMA tensor prefetch Op
This patch updates the TMA Tensor prefetch Op
with support for im2col_w/w128 and tile_gather4 modes.
This completes support for all modes available in
Blackwell.
* lit tests are added for all possible combinations.
* The invalid tests are moved to a separate file
with more coverage.
Signed-off-by: Durgadoss R <durgadossr at nvidia.com>
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 94 +++++++++-----
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 118 ++++++++++++++----
.../test/Target/LLVMIR/nvvm/tma_prefetch.mlir | 117 ++++++++++++-----
.../LLVMIR/nvvm/tma_prefetch_invalid.mlir | 56 +++++++++
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 26 ----
5 files changed, 292 insertions(+), 119 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/nvvm/tma_prefetch_invalid.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f9cd58de8915f..40e264202a68c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2256,6 +2256,56 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
+// List of modes supported for TMA Load and Prefetch Ops
+def TMALoadModeTile : I32EnumAttrCase<"TILE", 0, "tile">;
+def TMALoadModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
+def TMALoadModeIm2ColW : I32EnumAttrCase<"IM2COL_W", 2, "im2col_w">;
+def TMALoadModeIm2ColW128 : I32EnumAttrCase<"IM2COL_W_128", 3, "im2col_w_128">;
+def TMALoadModeTileGather4 : I32EnumAttrCase<"TILE_GATHER4", 4, "tile_gather4">;
+
+def TMALoadMode : I32EnumAttr<"TMALoadMode", "NVVM TMA Load Mode",
+ [TMALoadModeTile, TMALoadModeIm2Col,
+ TMALoadModeIm2ColW, TMALoadModeIm2ColW128,
+ TMALoadModeTileGather4]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def TMALoadModeAttr : EnumAttr<NVVM_Dialect, TMALoadMode, "tma_load_mode"> {
+ let summary = "List of Load-Modes supported for TMA Tensor Ops";
+ let description = [{
+ TMA Tensor Ops support the following modes, when copying data from
+ global memory to shared memory (i.e. load):
+
+ Tile Mode: It's the default mode. The source multi-dimensional tensor
+ layout is preserved at the destination.
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-mode)
+
+ Im2col Mode: This mode is used when `im2colOffsets` operands are present.
+ The elements in the Bounding Box of the source tensor are rearranged into
+ columns at the destination. In this mode, the tensor has to be at least
+ 3-dimensional. The number of `im2colOffsets` is `dims - 2` where `dims`
+ is the dimension of the tensor.
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode)
+
+ Im2col_W Mode: This mode is similar to Im2Col mode with the restriction that
+ elements are accessed across the W dimension only. The number of `im2colOffsets`
+ is always two, referred to as `wHalo` and `wOffset`.
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-w-w128-modes)
+
+ Im2col_W_128 Mode: This mode is similar to Im2Col_W mode, except that the number
+ of elements accessed across the W dimension is always 128.
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-w-w128-modes)
+
+ Tile_Gather4 Mode: This mode is similar to Tile mode but works only on 2D tensors.
+ In gather4 mode, four rows in the source 2D tensor are combined to form a single
+ 2D tensor at the destination. This mode requires five co-ordinates. The first one
+ represents the column-index followed by four row indices.
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-scatter4-gather4-modes)
+ }];
+
+ let assemblyFormat = "`<` $value `>`";
+}
+
def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
@@ -2524,23 +2574,16 @@ def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> {
def NVVM_CpAsyncBulkTensorPrefetchOp :
NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
let arguments = (ins
- LLVM_AnyPointer:$tmaDescriptor,
+ LLVM_PointerGeneric:$tmaDescriptor,
Variadic<I32>:$coordinates,
Variadic<I16>:$im2colOffsets,
+ DefaultValuedAttr<TMALoadModeAttr, "TMALoadMode::TILE">:$mode,
Optional<I64>:$l2CacheHint);
let description = [{
Initiates an asynchronous prefetch operation on the tensor data from global
- memory to L2 cache.
-
- The Op has two modes:
- 1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
- layout is preserved at the destination.
-
- 2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
- the elements in the Bounding Box of the source tensor are rearranged into
- columns at the destination. In this mode, the tensor has to be at least
- 3-dimensional.
+ memory to L2 cache. This Op supports all the load modes specified in
+ `TMALoadMode`.
The `l2CacheHint` operand is optional, and it is used to specify cache
eviction policy that may be used during the memory access.
@@ -2557,34 +2600,17 @@ def NVVM_CpAsyncBulkTensorPrefetchOp :
}];
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col);
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder);
}];
let hasVerifier = 1;
string llvmBuilder = [{
- // Arguments to the intrinsic:
- // tmaDesc, tensorDims, im2colOffsets
- // cache_hint(if applicable) and flag(boolean)
- llvm::SmallVector<llvm::Value *> translatedOperands;
- translatedOperands.push_back($tmaDescriptor);
-
- for (auto v : op.getCoordinates())
- translatedOperands.push_back(moduleTranslation.lookupValue(v));
-
- for (auto v : op.getIm2colOffsets())
- translatedOperands.push_back(moduleTranslation.lookupValue(v));
-
- llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
- auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
-
- bool isCacheHint = op.getL2CacheHint() ? true : false;
- translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
- translatedOperands.push_back(builder.getInt1(isCacheHint));
-
- auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID(
- op.getCoordinates().size(), op.getIm2colOffsets().size() > 0);
- createIntrinsicCall(builder, intId, translatedOperands);
+ auto [id, args] = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder);
+ createIntrinsicCall(builder, id, args);
}];
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index dbcc738b4419f..f8ac6fc413d77 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -50,7 +50,6 @@ using namespace NVVM;
// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
-// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
bool isIm2Col,
@@ -98,11 +97,47 @@ LogicalResult CpAsyncOp::verify() {
return success();
}
+// These parameter checks can be shared across TMA Load and Prefetch Ops.
+static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
+ TMALoadMode mode, Location loc) {
+ if (tensorDims < 1 || tensorDims > 5)
+ return emitError(loc, "expects coordinates between 1 to 5 dimension");
+
+ auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
+ size_t expectedIm2colOff) -> LogicalResult {
+ if (isIm2col && (tensorDims < 3))
+ return emitError(loc)
+ << "to use " << stringifyEnum(mode)
+ << " mode, the tensor has to be at least 3-dimensional";
+
+ if (numIm2colOff != expectedIm2colOff)
+ return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
+ << " (provided " << numIm2colOff << ")";
+
+ return success();
+ };
+
+ switch (mode) {
+ case TMALoadMode::TILE:
+ return checkTMALoadParams(mode, false, 0);
+ case TMALoadMode::IM2COL:
+ return checkTMALoadParams(mode, true, tensorDims - 2);
+ case TMALoadMode::IM2COL_W:
+ case TMALoadMode::IM2COL_W_128:
+ return checkTMALoadParams(mode, true, 2);
+ case TMALoadMode::TILE_GATHER4:
+ return (tensorDims == 5)
+ ? checkTMALoadParams(mode, false, 0)
+ : emitError(loc, "Gather4 mode expects 5 coordinates");
+ default:
+ return emitError(loc, "Invalid LoadMode in CpAsyncBulkTensorPrefetchOp.");
+ }
+ return success();
+}
+
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
- size_t numIm2ColOffsets = getIm2colOffsets().size();
- bool isIm2Col = numIm2ColOffsets > 0;
- return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
- numIm2ColOffsets, getLoc());
+ return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
+ getMode(), getLoc());
}
LogicalResult CpAsyncBulkTensorReduceOp::verify() {
@@ -1433,28 +1468,57 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}
-llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
- bool isIm2Col) {
- switch (tensorDims) {
- case 1:
- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
- case 2:
- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
- case 3:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
- case 4:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
- case 5:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
- default:
- llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
- }
+mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args
+ args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+ for (auto v : thisOp.getCoordinates())
+ args.push_back(mt.lookupValue(v));
+ for (auto v : thisOp.getIm2colOffsets())
+ args.push_back(mt.lookupValue(v));
+
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Unused =
+ llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+ args.push_back(builder.getInt1(hasCacheHint));
+
+ const unsigned NI = llvm::Intrinsic::not_intrinsic;
+ static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+ {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
+ {NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
+ {NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
+ {NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
+ {NI, NI, NI, NI, NI,
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
+
+ static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
+ "TMALoadModes must match number of rows in IDTable");
+ size_t mode = static_cast<size_t>(thisOp.getMode());
+ size_t dim = thisOp.getCoordinates().size();
+ llvm::Intrinsic::ID id = IDTable[mode][dim];
+ if (id == llvm::Intrinsic::not_intrinsic)
+ llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
+
+ return {id, std::move(args)};
}
#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir
index bfd952636ffbe..536b52b034db8 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir
@@ -1,70 +1,123 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
-// CHECK-LABEL: @tma_bulk_prefetch
llvm.func @tma_bulk_prefetch(%src : !llvm.ptr<1>, %size : i32, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ // CHECK-LABEL: define void @tma_bulk_prefetch(ptr addrspace(1) %0, i32 %1, i64 %2) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %0, i32 %1, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %0, i32 %1, i64 %2, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
nvvm.cp.async.bulk.prefetch %src, %size : !llvm.ptr<1>
nvvm.cp.async.bulk.prefetch %src, %size l2_cache_hint = %ch : !llvm.ptr<1>
llvm.return
}
-// CHECK-LABEL: @tma_prefetch_1d
llvm.func @tma_prefetch_1d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ // CHECK-LABEL: define void @tma_prefetch_1d(ptr %0, i32 %1, i64 %2) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %1, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %1, i64 %2, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] l2_cache_hint = %ch : !llvm.ptr
llvm.return
}
-// CHECK-LABEL: @tma_prefetch_2d
llvm.func @tma_prefetch_2d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] : !llvm.ptr
+ // CHECK-LABEL: define void @tma_prefetch_2d(ptr %0, i32 %1, i32 %2, i64 %3) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %1, i32 %2, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %1, i32 %2, i64 %3, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] {mode = #nvvm.tma_load_mode<tile>} : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] l2_cache_hint = %ch : !llvm.ptr
llvm.return
}
-// CHECK-LABEL: @tma_prefetch_3d
-llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
+ // CHECK-LABEL: define void @tma_prefetch_3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 %6) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %1, i32 %2, i32 %3, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %1, i32 %2, i32 %3, i64 %6, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i64 %6, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 %6, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 %6, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] l2_cache_hint = %ch : !llvm.ptr
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] : !llvm.ptr
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
llvm.return
}
-// CHECK-LABEL: @tma_prefetch_4d
llvm.func @tma_prefetch_4d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ // CHECK-LABEL: define void @tma_prefetch_4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i64 %7, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch : !llvm.ptr
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] : !llvm.ptr
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
llvm.return
}
-// CHECK-LABEL: @tma_prefetch_5d
llvm.func @tma_prefetch_5d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ // CHECK-LABEL: define void @tma_prefetch_5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i16 %8, i64 %9) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 %9, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i16 %8, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i16 %8, i64 %9, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 %9, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 %9, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] l2_cache_hint = %ch : !llvm.ptr
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] : !llvm.ptr
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] l2_cache_hint = %ch : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+ llvm.return
+}
+
+llvm.func @tma_prefetch_gather4_2d(%tma_desc : !llvm.ptr, %x0 : i32, %y1 : i32, %y2 : i32, %y3 : i32, %y4 : i32, %ch : i64) {
+ // CHECK-LABEL: define void @tma_prefetch_gather4_2d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 %6) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.gather4.2d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.gather4.2d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 %6, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%x0, %y1, %y2, %y3, %y4] {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%x0, %y1, %y2, %y3, %y4] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_prefetch_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_prefetch_invalid.mlir
new file mode 100644
index 0000000000000..23e47bd594b78
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_prefetch_invalid.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+llvm.func @tma_prefetch_0d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
+ // expected-error @below {{expects coordinates between 1 to 5 dimension}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[] : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_2d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %off0 : i16, %ch : i64) {
+ // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] im2col[%off0] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_5d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
+ // expected-error @below {{im2col offsets expected 3 (provided 2)}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_3d_im2col_w(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16) {
+ // expected-error @below {{im2col offsets expected 2 (provided 1)}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_4d_im2col_w_128(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16) {
+ // expected-error @below {{im2col offsets expected 2 (provided 1)}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_gather4_3d(%tma_desc : !llvm.ptr, %d0 : i32) {
+ // expected-error @below {{Gather4 mode expects 5 coordinates}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_gather4_2d(%tma_desc : !llvm.ptr, %x0 : i32, %y1 : i32, %y2 : i32, %y3 : i32, %y4 : i32, %off0 : i16, %ch : i64) {
+ // expected-error @below {{im2col offsets expected 0 (provided 1)}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%x0, %y1, %y2, %y3, %y4] im2col[%off0] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
+ llvm.return
+}
+
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 33398cfb92429..53c7d5f741575 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -86,32 +86,6 @@ llvm.func @nvvm_fence_proxy_release() {
llvm.return
}
-// -----
-
-llvm.func @tma_prefetch_0d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
- // expected-error @below {{expects coordinates between 1 to 5 dimension}}
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[] : !llvm.ptr
- llvm.return
-}
-
-// -----
-
-llvm.func @tma_prefetch_2d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %off0 : i16, %ch : i64) {
- // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
- llvm.return
-}
-
-// -----
-
-llvm.func @tma_prefetch_5d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
- // expected-error @below {{im2col offsets must be 2 less than number of coordinates}}
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] : !llvm.ptr
- llvm.return
-}
-
-// -----
-
llvm.func @tma_reduce_0d(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %ch : i64) {
// expected-error @below {{expects coordinates between 1 to 5 dimension}}
nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[] {redKind = #nvvm.tma_redux_kind<add>}: !llvm.ptr, !llvm.ptr<3>
More information about the Mlir-commits
mailing list