[llvm] [InstCombine] Fix for folding `select` into floating point binary operators. (PR #83200)

Paul Osmialowski via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 30 04:02:17 PDT 2024


================
@@ -536,19 +536,29 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
     // between 0, 1 and -1.
     const APInt *OOpC;
     bool OOpIsAPInt = match(OOp, m_APInt(OOpC));
-    if (!isa<Constant>(OOp) ||
-        (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) {
-      Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp,
-                                           Swapped ? OOp : C, "", &SI);
-      if (isa<FPMathOperator>(&SI))
-        cast<Instruction>(NewSel)->setFastMathFlags(FMF);
-      NewSel->takeName(TVI);
-      BinaryOperator *BO =
-          BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel);
-      BO->copyIRFlags(TVI);
-      return BO;
-    }
-    return nullptr;
+    if (isa<Constant>(OOp) &&
+        (!OOpIsAPInt || !isSelect01(C->getUniqueInteger(), *OOpC)))
+      return nullptr;
+
+    // If the false value is a NaN then we have that the floating point math
+    // operation in the transformed code may not preserve the exact NaN
+    // bit-pattern -- e.g. `fadd sNaN, 0.0 -> qNaN`.
+    // This makes the transformation incorrect since the original program would
+    // have preserved the exact NaN bit-pattern.
+    // Avoid the folding if the false value might be a NaN.
+    if (isa<FPMathOperator>(&SI) &&
+        !computeKnownFPClass(FalseVal, FMF, fcNan, &SI).isKnownNeverNaN())
----------------
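The NaN concern in the patch comment above can be shown with ordinary C++ floats. This is a minimal sketch, not LLVM code. It models the two IR shapes, `select c, (fadd x, y), x` and `fadd x, (select c, y, -0.0)`, as plain functions. It assumes `-0.0` is the identity constant chosen for a plain `fadd` (no `nsz`). The names `bitsOf`, `isQuietNaN`, `originalForm` and `foldedForm` are hypothetical. The sketch also assumes IEEE-754 binary32, a target whose float arithmetic runs in SSE or AArch64 FP registers (x87 may quiet an sNaN just by loading it), and no `-ffast-math`.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>

// Reinterpret a float as its IEEE-754 binary32 bit pattern.
std::uint32_t bitsOf(float f) {
  std::uint32_t u;
  std::memcpy(&u, &f, sizeof u);
  return u;
}

// True for a quiet NaN: a NaN whose quiet bit (bit 22 in binary32) is set.
bool isQuietNaN(float f) {
  return f != f && (bitsOf(f) & 0x00400000u) != 0;
}

// Models the original shape:  select c, (fadd x, y), x
// When c is false, x is returned untouched, NaN payload included.
float originalForm(bool c, float x, float y) {
  return c ? x + y : x;
}

// Models the folded shape:  fadd x, (select c, y, -0.0)
// -0.0 is the additive identity, so for every non-NaN x both forms agree,
// but when c is false x still flows through an addition.
float foldedForm(bool c, float x, float y) {
  // volatile stops the C++ compiler from folding `x + -0.0` to `x` itself.
  volatile float negZero = -0.0f;
  float addend = c ? y : negZero;
  return x + addend;
}
```

Under those assumptions, calling both functions with `c == false` and `x = std::numeric_limits<float>::signaling_NaN()` should show the difference. `originalForm` hands back the signalling bit pattern unchanged, while `foldedForm` returns a quieted NaN. The `isKnownNeverNaN()` guard in the patch exists to rule out that case.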
pawosm-arm wrote:

I'm sorry, I'm a bit lost here.
`There's a wrapper around computeKnownFPClass which takes an additional FastMathFlags`
Did you mean
`The wrapper around computeKnownFPClass called here takes an additional FastMathFlags`?
That's how I read it: InstCombineSelect.cpp defines two wrappers around `computeKnownFPClass`:
```
  KnownFPClass computeKnownFPClass(Value *Val, FastMathFlags FMF,
                                   FPClassTest Interested = fcAllFlags,
                                   const Instruction *CtxI = nullptr,
                                   unsigned Depth = 0) const
```
and
```
  KnownFPClass computeKnownFPClass(Value *Val,
                                   FPClassTest Interested = fcAllFlags,
                                   const Instruction *CtxI = nullptr,
                                   unsigned Depth = 0) const
```
The one called here is the one which takes `FastMathFlags FMF`.
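The overload question above can be checked with a small sketch. It uses simplified, hypothetical stand-ins for `Value`, `Instruction`, `FastMathFlags` and `FPClassTest`, with illustrative enum values. It uses free functions instead of the `const` members, and each returns a tag naming the chosen overload instead of a `KnownFPClass`. A `FastMathFlags` argument cannot convert to `FPClassTest`, so a call shaped like the one in the patch, `computeKnownFPClass(FalseVal, FMF, fcNan, &SI)`, can only bind to the overload that takes `FastMathFlags`.

```cpp
#include <cassert>

// Hypothetical, simplified stand-ins; not the real LLVM definitions.
enum FPClassTest : unsigned { fcNan = 0x3, fcAllFlags = 0x3ff };
struct FastMathFlags { bool NoNaNs = false; };
struct Instruction {};
struct Value {};

// Tag recording which overload was selected (stands in for KnownFPClass).
enum class Overload { WithFMF, WithoutFMF };

// Same parameter shape as the first quoted wrapper (takes FastMathFlags).
Overload computeKnownFPClass(Value *Val, FastMathFlags FMF,
                             FPClassTest Interested = fcAllFlags,
                             const Instruction *CtxI = nullptr,
                             unsigned Depth = 0) {
  (void)Val; (void)FMF; (void)Interested; (void)CtxI; (void)Depth;
  return Overload::WithFMF;
}

// Same parameter shape as the second quoted wrapper (no FastMathFlags).
Overload computeKnownFPClass(Value *Val,
                             FPClassTest Interested = fcAllFlags,
                             const Instruction *CtxI = nullptr,
                             unsigned Depth = 0) {
  (void)Val; (void)Interested; (void)CtxI; (void)Depth;
  return Overload::WithoutFMF;
}
```

With these declarations, `computeKnownFPClass(&V, FMF, fcNan, &I)` resolves to the `FastMathFlags` overload. `computeKnownFPClass(&V, fcNan, &I)` and `computeKnownFPClass(&V)` resolve to the other one, because the first overload has no default for `FMF`.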


https://github.com/llvm/llvm-project/pull/83200

