[Mlir-commits] [mlir] [MLIR][NVVM] Update TMA Load Op (PR #156347)

Durgadoss R llvmlistbot at llvm.org
Wed Sep 17 06:03:34 PDT 2025


================
@@ -1535,6 +1560,123 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
   return {id, std::move(args)};
 }
 
+bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
+    RewriterBase &rewriter,
+    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+        &asmValues) {
+  // Append every operand, tagged PTXRegisterMod::Read, to asmValues; the
+  // attributes are deliberately left out. They only select the right variant
+  // during intrinsics-lowering, so inline-PTX generation ignores them.
+  for (auto val : getOperands())
+    asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
+
+  return false;
+}
+
+mlir::NVVM::IDArgPair
+CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
+  const bool isCTAOnly = thisOp.getIsCTAOnly();
+  llvm::SmallVector<llvm::Value *> args;
+
+  // Fill the Intrinsic Args
+  args.push_back(mt.lookupValue(thisOp.getDstMem()));
+  args.push_back(mt.lookupValue(thisOp.getMbar()));
+  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+  // Coordinates and im2col-offsets
+  for (auto v : thisOp.getCoordinates())
+    args.push_back(mt.lookupValue(v));
+  for (auto v : thisOp.getIm2colOffsets())
+    args.push_back(mt.lookupValue(v));
+
+  // MulticastMask, if available
+  mlir::Value mcMask = thisOp.getMulticastMask();
+  const bool hasMC = static_cast<bool>(mcMask);
+  llvm::Value *i16Zero =
+      llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
+
+  // CacheHint, if available
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Zero =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+
+  // Flag argument CTAGroup
+  // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
+  // Hence, the +1 to getGroup().
+  const int32_t val =
+      thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
+  llvm::Value *cg =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
+
+  if (!isCTAOnly) {
+    // For shared::cluster, all the arguments that we build are applicable.
+    args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
+    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
+    args.push_back(builder.getInt1(hasMC));
+    args.push_back(builder.getInt1(hasCacheHint));
+    args.push_back(cg);
+  } else {
+    // For shared::cta, only cache-hint is applicable.
+    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
+    args.push_back(builder.getInt1(hasCacheHint));
+  }
+
+  const unsigned NI = llvm::Intrinsic::not_intrinsic;
+  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
----------------
durga4github wrote:

Done.

https://github.com/llvm/llvm-project/pull/156347


More information about the Mlir-commits mailing list