[Mlir-commits] [mlir] [MLIR][NVVM] Add missing rounding modes in fp16x2 conversions (PR #169005)
Srinivasa Ravi
llvmlistbot at llvm.org
Fri Nov 21 04:50:06 PST 2025
================
@@ -2740,30 +2747,100 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
}()
-llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
- bool hasRelu = getRelu();
- bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+NVVM::IDArgPair
+ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rn,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rz,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rs,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
+ };
+
+ unsigned hasRelu = op.getRelu() ? 1 : 0;
+ unsigned hasSatFinite =
+ (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
+ // idx: bit-0 - relu
+ // bit-1 - satfinite
+ unsigned idx = (hasSatFinite << 1) | hasRelu;
- if (hasRelu && hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite;
- if (hasRelu)
- return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;
- if (hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite;
- return llvm::Intrinsic::nvvm_ff2f16x2_rs;
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getSrcHi()));
+ args.push_back(mt.lookupValue(op.getSrcLo()));
+ if (op.getRandomBits())
+ args.push_back(mt.lookupValue(op.getRandomBits()));
+
+ switch (op.getRnd()) {
+ case FPRoundingMode::RN:
+ return {rndRNIds[idx], std::move(args)};
+ case FPRoundingMode::RZ:
+ return {rndRZIds[idx], std::move(args)};
+ case FPRoundingMode::RS:
+ return {rndRSIds[idx], std::move(args)};
+ default:
+ llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
+ }
}
-llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
- bool hasRelu = getRelu();
- bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+NVVM::IDArgPair
+ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
+ };
----------------
Wolfram70 wrote:
I did consider having a single table initially but went with separate ones because the valid rounding modes `rn`, `rz`, and `rs` are not consecutive in the `FPRoundingMode` enum, so we'd have to do some sort of mapping, which didn't look very nice.
https://github.com/llvm/llvm-project/pull/169005
More information about the Mlir-commits
mailing list