[Mlir-commits] [mlir] [MLIR][NVVM] Update TMA tensor prefetch Op (PR #153464)
Guray Ozen
llvmlistbot at llvm.org
Thu Aug 14 03:44:25 PDT 2025
================
@@ -1399,28 +1430,60 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}
-llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
- bool isIm2Col) {
- switch (tensorDims) {
- case 1:
- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
- case 2:
- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
- case 3:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
- case 4:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
- case 5:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
- default:
- llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
- }
+#define GET_TMA_OPCODE(op, mode, dim) \
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
+
+mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args
+ args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+ for (auto v : thisOp.getCoordinates())
+ args.push_back(mt.lookupValue(v));
+ for (auto v : thisOp.getIm2colOffsets())
+ args.push_back(mt.lookupValue(v));
+
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Unused =
+ llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+ args.push_back(builder.getInt1(hasCacheHint));
+
+#define NI llvm::Intrinsic::not_intrinsic
+#define TILE(D) GET_TMA_OPCODE(prefetch, tile, D)
+#define IM2COL(D) GET_TMA_OPCODE(prefetch, im2col, D)
+#define IM2COLW(D) GET_TMA_OPCODE(prefetch, im2col_w, D)
+#define IM2COLW128(D) GET_TMA_OPCODE(prefetch, im2col_w_128, D)
+#define GATHER4(D) GET_TMA_OPCODE(prefetch, tile_gather4, D)
+
+ static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+ {NI, TILE(1), TILE(2), TILE(3), TILE(4), TILE(5)}, // tile
+ {NI, NI, NI, IM2COL(3), IM2COL(4), IM2COL(5)}, // im2col
+ {NI, NI, NI, IM2COLW(3), IM2COLW(4), IM2COLW(5)}, // im2col_w
+ {NI, NI, NI, IM2COLW128(3), IM2COLW128(4), IM2COLW128(5)}, // im2col_w128
+ {NI, NI, NI, NI, NI, GATHER4(2)}, // tile_gather4
+ };
+ static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
+ "TMALoadModes must match number of rows in IDTable");
+
+ size_t mode = static_cast<size_t>(thisOp.getMode());
+ size_t dim = thisOp.getCoordinates().size();
+ llvm::Intrinsic::ID id = IDTable[mode][dim];
+ if (id == llvm::Intrinsic::not_intrinsic)
+ llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
+
+ return {id, std::move(args)};
+
+#undef GATHER4
+#undef IM2COLW128
+#undef IM2COLW
+#undef IM2COL
+#undef TILE
+#undef NI
----------------
grypp wrote:
Can we implement this without the macros? I found them quite hard to read.
https://github.com/llvm/llvm-project/pull/153464
More information about the Mlir-commits mailing list