[Mlir-commits] [mlir] [MLIR][NVVM] Update TMA Load Op (PR #156347)
Durgadoss R
llvmlistbot at llvm.org
Wed Sep 17 06:03:34 PDT 2025
================
@@ -1535,6 +1560,123 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}
+bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
+    RewriterBase &rewriter,
+    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+        &asmValues) {
+  // Collect the values used when emitting inline-PTX for this op.
+  // Only the op's operands are appended, each tagged as a read; the
+  // attributes are intentionally left out because they only select the
+  // right variant during intrinsics-lowering.
+  for (mlir::Value operand : getOperands())
+    asmValues.emplace_back(operand, mlir::NVVM::PTXRegisterMod::Read);
+
+  // NOTE(review): the meaning of the boolean result is defined by the
+  // PTX-builder interface, which is not visible here; this op always
+  // reports false.
+  return false;
+}
+
+mlir::NVVM::IDArgPair
+CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
+ const bool isCTAOnly = thisOp.getIsCTAOnly();
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args
+ args.push_back(mt.lookupValue(thisOp.getDstMem()));
+ args.push_back(mt.lookupValue(thisOp.getMbar()));
+ args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+ // Coordinates and im2col-offsets
+ for (auto v : thisOp.getCoordinates())
+ args.push_back(mt.lookupValue(v));
+ for (auto v : thisOp.getIm2colOffsets())
+ args.push_back(mt.lookupValue(v));
+
+ // MulticastMask, if available
+ mlir::Value mcMask = thisOp.getMulticastMask();
+ const bool hasMC = static_cast<bool>(mcMask);
+ llvm::Value *i16Zero =
+ llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
+
+ // CacheHint, if available
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Zero =
+ llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+
+ // Flag argument CTAGroup
+ // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
+ // Hence, the +1 to getGroup().
+ const int32_t val =
+ thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
+ llvm::Value *cg =
+ llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
+
+ if (!isCTAOnly) {
+ // For shared::cluster, all the arguments that we build are applicable.
+ args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
+ args.push_back(builder.getInt1(hasMC));
+ args.push_back(builder.getInt1(hasCacheHint));
+ args.push_back(cg);
+ } else {
+ // For shared::cta, only cache-hint is applicable.
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
+ args.push_back(builder.getInt1(hasCacheHint));
+ }
+
+ const unsigned NI = llvm::Intrinsic::not_intrinsic;
+ static constexpr llvm::Intrinsic::ID IDTable[][6] = {
----------------
durga4github wrote:
Done.
https://github.com/llvm/llvm-project/pull/156347
More information about the Mlir-commits
mailing list