[Mlir-commits] [mlir] [MLIR][NVVM] Update TMA tensor prefetch Op (PR #153464)
Guray Ozen
llvmlistbot at llvm.org
Thu Aug 14 03:48:24 PDT 2025
================
@@ -1399,28 +1430,60 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}
-llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
- bool isIm2Col) {
- switch (tensorDims) {
- case 1:
- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
- case 2:
- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
- case 3:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
- case 4:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
- case 5:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
- default:
- llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
- }
+#define GET_TMA_OPCODE(op, mode, dim) \
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
+
+mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args
+ args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+ for (auto v : thisOp.getCoordinates())
+ args.push_back(mt.lookupValue(v));
+ for (auto v : thisOp.getIm2colOffsets())
+ args.push_back(mt.lookupValue(v));
+
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Unused =
+ llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+ args.push_back(builder.getInt1(hasCacheHint));
+
+#define NI llvm::Intrinsic::not_intrinsic
+#define TILE(D) GET_TMA_OPCODE(prefetch, tile, D)
+#define IM2COL(D) GET_TMA_OPCODE(prefetch, im2col, D)
+#define IM2COLW(D) GET_TMA_OPCODE(prefetch, im2col_w, D)
+#define IM2COLW128(D) GET_TMA_OPCODE(prefetch, im2col_w_128, D)
+#define GATHER4(D) GET_TMA_OPCODE(prefetch, tile_gather4, D)
+
+ static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+ {NI, TILE(1), TILE(2), TILE(3), TILE(4), TILE(5)}, // tile
+ {NI, NI, NI, IM2COL(3), IM2COL(4), IM2COL(5)}, // im2col
+ {NI, NI, NI, IM2COLW(3), IM2COLW(4), IM2COLW(5)}, // im2col_w
+ {NI, NI, NI, IM2COLW128(3), IM2COLW128(4), IM2COLW128(5)}, // im2col_w128
+ {NI, NI, NI, NI, NI, GATHER4(2)}, // tile_gather4
+ };
+ static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
+ "TMALoadModes must match number of rows in IDTable");
+
+ size_t mode = static_cast<size_t>(thisOp.getMode());
+ size_t dim = thisOp.getCoordinates().size();
+ llvm::Intrinsic::ID id = IDTable[mode][dim];
+ if (id == llvm::Intrinsic::not_intrinsic)
+ llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
+
+ return {id, std::move(args)};
+
+#undef GATHER4
+#undef IM2COLW128
+#undef IM2COLW
+#undef IM2COL
+#undef TILE
+#undef NI
----------------
grypp wrote:
We can create constexpr tables and use them; they are zero-cost as well. I drafted an example below:
```
constexpr llvm::Intrinsic::ID NI = llvm::Intrinsic::not_intrinsic;
static constexpr llvm::Intrinsic::ID IDTable[][6] = {
// tile
{
NI,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d,
},
// im2col
{
NI, NI, NI,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d,
},
// im2col_w
{
NI, NI, NI,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d,
},
// im2col_w_128
{
NI, NI, NI,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d,
},
// tile_gather4
{
NI, NI, NI, NI, NI,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d,
},
};
static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
"TMALoadModes must match number of rows in IDTable");
const size_t mode = static_cast<size_t>(thisOp.getMode());
const size_t dim = thisOp.getCoordinates().size();
if (mode >= std::size(IDTable) || dim >= std::size(IDTable[0]))
llvm_unreachable("Mode or dimension out of range for CpAsyncBulkTensorPrefetchOp.");
llvm::Intrinsic::ID id = IDTable[mode][dim];
if (id == llvm::Intrinsic::not_intrinsic)
llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
```
https://github.com/llvm/llvm-project/pull/153464
More information about the Mlir-commits
mailing list