[Mlir-commits] [mlir] [MLIR][NVVM] Update TMA tensor prefetch Op (PR #153464)

Guray Ozen llvmlistbot at llvm.org
Thu Aug 14 03:44:25 PDT 2025


================
@@ -1399,28 +1430,60 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
   return {id, std::move(args)};
 }
 
-llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
-                                                                bool isIm2Col) {
-  switch (tensorDims) {
-  case 1:
-    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
-  case 2:
-    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
-  case 3:
-    return isIm2Col
-               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
-               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
-  case 4:
-    return isIm2Col
-               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
-               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
-  case 5:
-    return isIm2Col
-               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
-               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
-  default:
-    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
-  }
+#define GET_TMA_OPCODE(op, mode, dim)                                          \
+  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
+
+mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+
+  // Fill the Intrinsic Args
+  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+  for (auto v : thisOp.getCoordinates())
+    args.push_back(mt.lookupValue(v));
+  for (auto v : thisOp.getIm2colOffsets())
+    args.push_back(mt.lookupValue(v));
+
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Unused =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+  args.push_back(builder.getInt1(hasCacheHint));
+
+#define NI llvm::Intrinsic::not_intrinsic
+#define TILE(D) GET_TMA_OPCODE(prefetch, tile, D)
+#define IM2COL(D) GET_TMA_OPCODE(prefetch, im2col, D)
+#define IM2COLW(D) GET_TMA_OPCODE(prefetch, im2col_w, D)
+#define IM2COLW128(D) GET_TMA_OPCODE(prefetch, im2col_w_128, D)
+#define GATHER4(D) GET_TMA_OPCODE(prefetch, tile_gather4, D)
+
+  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+      {NI, TILE(1), TILE(2), TILE(3), TILE(4), TILE(5)},         // tile
+      {NI, NI, NI, IM2COL(3), IM2COL(4), IM2COL(5)},             // im2col
+      {NI, NI, NI, IM2COLW(3), IM2COLW(4), IM2COLW(5)},          // im2col_w
+      {NI, NI, NI, IM2COLW128(3), IM2COLW128(4), IM2COLW128(5)}, // im2col_w128
+      {NI, NI, NI, NI, NI, GATHER4(2)},                          // tile_gather4
+  };
+  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
+                "TMALoadModes must match number of rows in IDTable");
+
+  size_t mode = static_cast<size_t>(thisOp.getMode());
+  size_t dim = thisOp.getCoordinates().size();
+  llvm::Intrinsic::ID id = IDTable[mode][dim];
+  if (id == llvm::Intrinsic::not_intrinsic)
+    llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
+
+  return {id, std::move(args)};
+
+#undef GATHER4
+#undef IM2COLW128
+#undef IM2COLW
+#undef IM2COL
+#undef TILE
+#undef NI
----------------
grypp wrote:

Can we implement this without the macros? I find them quite hard to read.



https://github.com/llvm/llvm-project/pull/153464


More information about the Mlir-commits mailing list