[Mlir-commits] [mlir] [MLIR][NVVM] Add Op for TMA Prefetch (PR #116232)
Durgadoss R
llvmlistbot at llvm.org
Thu Nov 14 11:28:42 PST 2024
https://github.com/durga4github updated https://github.com/llvm/llvm-project/pull/116232
From d94f0fd1b3091a959f0c11f462d23cf29c0508ea Mon Sep 17 00:00:00 2001
From: Durgadoss R <durgadossr at nvidia.com>
Date: Tue, 12 Nov 2024 17:35:39 +0530
Subject: [PATCH] [MLIR][NVVM] Add Op for TMA Prefetch
PR #115527 adds intrinsics for TMA prefetch.
This patch adds the corresponding NVVM Dialect Op.
Lit tests verifying the lowering to LLVM intrinsics,
as well as verifier tests for the invalid cases, are
added.
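
For example (a sketch; the operand names are those used in the lit
tests), a 1D tile-mode prefetch with an L2 cache hint

  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] l2_cache_hint = %ch : !llvm.ptr

lowers to the @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d intrinsic.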
Signed-off-by: Durgadoss R <durgadossr at nvidia.com>
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 68 +++++++++++++++++++++
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 57 ++++++++++++++---
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 26 +++++++-
mlir/test/Target/LLVMIR/nvvmir.mlir | 62 +++++++++++++++++++
4 files changed, 203 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7cb4b5c346ad97..6b462de144d1ff 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1949,6 +1949,74 @@ def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
}];
}
+def NVVM_CpAsyncBulkTensorPrefetchOp :
+ NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
+ let arguments = (ins
+ LLVM_AnyPointer:$tmaDescriptor,
+ Variadic<I32>:$coordinates,
+ Variadic<I16>:$im2colOffsets,
+ Optional<I64>:$l2CacheHint);
+
+ let description = [{
+ Initiates an asynchronous prefetch of the tensor data from global
+ memory into the L2 cache.
+
+ The Op has two modes:
+ 1) Tiled Mode: This is the default mode. The multi-dimensional layout
+ of the source tensor is preserved at the destination.
+
+ 2) Im2col Mode: This mode is used when the `im2colOffsets` operands are
+ present. The elements in the Bounding Box of the source tensor are
+ rearranged into columns at the destination. In this mode, the tensor
+ has to be at least 3-dimensional.
+
+ The `l2CacheHint` operand is optional, and it is used to specify a cache
+ eviction policy that may be applied during the memory access.
+
+ [For more information, see PTX ISA]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor)
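+
+ Example (a usage sketch; the operand names follow the lit tests in
+ this patch):
+
+ ```mlir
+ // Tiled mode (default), 2D box, with an L2 cache hint:
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] l2_cache_hint = %ch : !llvm.ptr
+
+ // Im2col mode, 3D box with one im2col offset:
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] : !llvm.ptr
+ ```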
+ }];
+
+ let assemblyFormat = [{
+ $tmaDescriptor `,`
+ `box` `[`$coordinates `]`
+ (`im2col` `[` $im2colOffsets^ `]` )?
+ (`l2_cache_hint` `=` $l2CacheHint^ )?
+ attr-dict `:` type($tmaDescriptor)
+ }];
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col);
+ }];
+
+ let hasVerifier = 1;
+
+ string llvmBuilder = [{
+ // Arguments to the intrinsic:
+ // tmaDesc, tensor coordinates, im2col offsets,
+ // cache hint (if applicable) and a boolean flag for the hint
+ llvm::SmallVector<llvm::Value *> translatedOperands;
+ translatedOperands.push_back($tmaDescriptor);
+
+ for (auto v : op.getCoordinates())
+ translatedOperands.push_back(moduleTranslation.lookupValue(v));
+
+ for (auto v : op.getIm2colOffsets())
+ translatedOperands.push_back(moduleTranslation.lookupValue(v));
+
+ llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
+ auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64));
+
+ bool isCacheHint = static_cast<bool>(op.getL2CacheHint());
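+ // When the hint is absent, an i64 undef value is passed along with a
+ // false flag.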
+ translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef);
+ translatedOperands.push_back(builder.getInt1(isCacheHint));
+
+ auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID(
+ op.getCoordinates().size(), !op.getIm2colOffsets().empty());
+ createIntrinsicCall(builder, intId, translatedOperands);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 5ab64ea1b2097a..d28194d5c00298 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -75,22 +75,32 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
-LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
- if (getCoordinates().empty() || getCoordinates().size() > 5)
- return emitError("expects coordinates between 1 to 5 dimension");
-
- // Check for im2col mode
- if (!getIm2colOffsets().empty()) {
- if (getCoordinates().size() < 3)
+// This verifier is shared across:
+// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load) and
+// CpAsyncBulkTensorPrefetchOp (TMA Prefetch) Ops.
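+// For example, a 5D tensor in im2col mode must carry exactly three
+// im2col offsets.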
+static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
+ size_t numIm2ColOffsets,
+ Location loc) {
+ if (tensorDims < 1 || tensorDims > 5)
+ return emitError(loc, "expects coordinates between 1 to 5 dimension");
+
+ if (numIm2ColOffsets) {
+ if (tensorDims < 3)
return emitError(
+ loc,
"to use im2col mode, the tensor has to be at least 3-dimensional");
- if (getCoordinates().size() != (getIm2colOffsets().size() + 2))
+ if (tensorDims != (numIm2ColOffsets + 2))
return emitError(
- "im2col offsets must be 2 less than number of coordinates");
+ loc, "im2col offsets must be 2 less than number of coordinates");
}
return success();
}
+LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
+ return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
+ getIm2colOffsets().size(), getLoc());
+}
+
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
if (getCoordinates().size() > 5)
return emitError("Maximum 5 coordinates and dimension is supported.");
@@ -108,6 +118,11 @@ LogicalResult CpAsyncOp::verify() {
return success();
}
+LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
+ return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
+ getIm2colOffsets().size(), getLoc());
+}
+
// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
@@ -1055,6 +1070,30 @@ LogicalResult NVVM::BarrierOp::verify() {
return success();
}
+llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
+ bool isIm2Col) {
+ switch (tensorDims) {
+ case 1:
+ return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
+ case 2:
+ return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
+ case 3:
+ return isIm2Col
+ ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
+ : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
+ case 4:
+ return isIm2Col
+ ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
+ : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
+ case 5:
+ return isIm2Col
+ ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
+ : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
+ default:
+ llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
+ }
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 0e563808da970b..58282adf4dda85 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -30,4 +30,28 @@ llvm.func @nvvm_fence_proxy_release() {
// expected-error @below {{'nvvm.fence.proxy.release' op uni-directional proxies only support tensormap for to_proxy attribute}}
nvvm.fence.proxy.release #nvvm.mem_scope<cta> from_proxy=#nvvm.proxy_kind<generic> to_proxy=#nvvm.proxy_kind<generic>
llvm.return
-}
\ No newline at end of file
+}
+
+// -----
+
+llvm.func @tma_prefetch_0d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
+ // expected-error @below {{expects coordinates between 1 to 5 dimension}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[] : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_2d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %off0 : i16, %ch : i64) {
+ // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_5d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
+ // expected-error @below {{im2col offsets must be 2 less than number of coordinates}}
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] : !llvm.ptr
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 75ce958b43fd34..e5ea03ff7e0017 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -715,3 +715,65 @@ llvm.func @nvvm_breakpoint() {
nvvm.breakpoint
llvm.return
}
+
+// -----
+
+// CHECK-LABEL: @tma_prefetch_1d
+llvm.func @tma_prefetch_1d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] l2_cache_hint = %ch : !llvm.ptr
+ llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_2d
+llvm.func @tma_prefetch_2d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %ch : i64) {
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] l2_cache_hint = %ch : !llvm.ptr
+ llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_3d
+llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %ch : i64) {
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] l2_cache_hint = %ch : !llvm.ptr
+
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
+ llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_4d
+llvm.func @tma_prefetch_4d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch : !llvm.ptr
+
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch : !llvm.ptr
+ llvm.return
+}
+
+// CHECK-LABEL: @tma_prefetch_5d
+llvm.func @tma_prefetch_5d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] l2_cache_hint = %ch : !llvm.ptr
+
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
+ // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] l2_cache_hint = %ch : !llvm.ptr
+ llvm.return
+}