[Mlir-commits] [mlir] [MLIR][NVVM] Add missing rounding modes in fp16x2 conversions (PR #169005)

Guray Ozen llvmlistbot at llvm.org
Fri Nov 21 04:41:45 PST 2025


================
@@ -2740,30 +2747,100 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
     return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
   }()
 
-llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
-  bool hasRelu = getRelu();
-  bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+NVVM::IDArgPair
+ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
+                                             LLVM::ModuleTranslation &mt,
+                                             llvm::IRBuilderBase &builder) {
+  static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+      llvm::Intrinsic::nvvm_ff2f16x2_rn,
+      llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
+      llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
+      llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
+  };
+  static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+      llvm::Intrinsic::nvvm_ff2f16x2_rz,
+      llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
+      llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
+      llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
+  };
+  static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+      llvm::Intrinsic::nvvm_ff2f16x2_rs,
+      llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
+      llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
+      llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
+  };
+
+  unsigned hasRelu = op.getRelu() ? 1 : 0;
+  unsigned hasSatFinite =
+      (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
+  // idx: bit-0 - relu
+  //      bit-1 - satfinite
+  unsigned idx = (hasSatFinite << 1) | hasRelu;
 
-  if (hasRelu && hasSatFinite)
-    return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite;
-  if (hasRelu)
-    return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;
-  if (hasSatFinite)
-    return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite;
-  return llvm::Intrinsic::nvvm_ff2f16x2_rs;
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(op.getSrcHi()));
+  args.push_back(mt.lookupValue(op.getSrcLo()));
+  if (op.getRandomBits())
+    args.push_back(mt.lookupValue(op.getRandomBits()));
+
+  switch (op.getRnd()) {
+  case FPRoundingMode::RN:
+    return {rndRNIds[idx], std::move(args)};
+  case FPRoundingMode::RZ:
+    return {rndRZIds[idx], std::move(args)};
+  case FPRoundingMode::RS:
+    return {rndRSIds[idx], std::move(args)};
+  default:
+    llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
+  }
 }
 
-llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
-  bool hasRelu = getRelu();
-  bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+NVVM::IDArgPair
+ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
+                                              LLVM::ModuleTranslation &mt,
+                                              llvm::IRBuilderBase &builder) {
+  static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+      llvm::Intrinsic::nvvm_ff2bf16x2_rn,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
+  };
+  static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+      llvm::Intrinsic::nvvm_ff2bf16x2_rz,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
+  };
+  static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+      llvm::Intrinsic::nvvm_ff2bf16x2_rs,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
+      llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
+  };
----------------
grypp wrote:

Just thinking out loud: can we combine these tables for bf16 and f16 into a single one?
```
static constexpr llvm::Intrinsic::ID ff2x16IntrinsicSet
    [2 /*Fp16Kind*/]
    [3 /*RoundingMode*/]
    [4 /*PostOp*/] = {
  // ===== F16 =====
  {
    // RN
    {
      llvm::Intrinsic::nvvm_ff2f16x2_rn,
      llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
      llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
      llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
    },
     ....
}
```

And also write a selector function:
```
inline llvm::Intrinsic::ID getIntrinsic(Fp16Kind kind,RoundingMode rnd, Post post) {
  return ff2x16IntrinsicSet[kind][rnd][post]
} 
```

Then you can select the intrinsic nicely:
```
llvm::Intrinsic::ID it = getIntrinsic(op.getType(), op.getRnd(), op.getMode());
```


https://github.com/llvm/llvm-project/pull/169005


More information about the Mlir-commits mailing list