[llvm] ConstantFolding: Do not fold fcmp of denormal without known mode (PR #115407)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 8 01:26:31 PST 2024


================
@@ -1298,47 +1300,110 @@ Constant *llvm::ConstantFoldBinaryOpOperands(unsigned Opcode, Constant *LHS,
   return ConstantFoldBinaryInstruction(Opcode, LHS, RHS);
 }
 
-Constant *llvm::FlushFPConstant(Constant *Operand, const Instruction *I,
-                                bool IsOutput) {
-  if (!I || !I->getParent() || !I->getFunction())
-    return Operand;
+static ConstantFP *flushDenormalConstant(Type *Ty, const APFloat &APF,
+                                         DenormalMode::DenormalModeKind Mode) {
+  switch (Mode) {
+  case DenormalMode::Dynamic:
+    return nullptr;
+  case DenormalMode::IEEE:
+    return ConstantFP::get(Ty->getContext(), APF);
+  case DenormalMode::PreserveSign:
+    return ConstantFP::get(
+        Ty->getContext(),
+        APFloat::getZero(APF.getSemantics(), APF.isNegative()));
+  case DenormalMode::PositiveZero:
+    return ConstantFP::get(Ty->getContext(),
+                           APFloat::getZero(APF.getSemantics(), false));
+  default:
+    break;
+  }
 
-  ConstantFP *CFP = dyn_cast<ConstantFP>(Operand);
-  if (!CFP)
-    return Operand;
+  llvm_unreachable("unknown denormal mode");
+}
+
+/// Return the denormal mode that can be assumed when executing a floating point
+/// operation at \p CtxI.
+static DenormalMode getInstrDenormalMode(const Instruction *CtxI, Type *Ty) {
+  if (!CtxI || !CtxI->getParent() || !CtxI->getFunction())
+    return DenormalMode::getDynamic();
+  return CtxI->getFunction()->getDenormalMode(Ty->getFltSemantics());
+}
 
+static ConstantFP *flushDenormalConstantFP(ConstantFP *CFP,
+                                           const Instruction *Inst,
+                                           bool IsOutput) {
   const APFloat &APF = CFP->getValueAPF();
-  // TODO: Should this canonicalize nans?
   if (!APF.isDenormal())
-    return Operand;
+    return CFP;
 
-  Type *Ty = CFP->getType();
-  DenormalMode DenormMode =
-      I->getFunction()->getDenormalMode(Ty->getFltSemantics());
-  DenormalMode::DenormalModeKind Mode =
-      IsOutput ? DenormMode.Output : DenormMode.Input;
-  switch (Mode) {
-  default:
-    llvm_unreachable("unknown denormal mode");
-  case DenormalMode::Dynamic:
-    return nullptr;
-  case DenormalMode::IEEE:
+  DenormalMode Mode = getInstrDenormalMode(Inst, CFP->getType());
+  return flushDenormalConstant(CFP->getType(), APF,
+                               IsOutput ? Mode.Output : Mode.Input);
+}
+
+Constant *llvm::FlushFPConstant(Constant *Operand, const Instruction *Inst,
+                                bool IsOutput) {
+  if (ConstantFP *CFP = dyn_cast<ConstantFP>(Operand))
+    return flushDenormalConstantFP(CFP, Inst, IsOutput);
+
+  if (isa<ConstantAggregateZero, UndefValue>(Operand))
     return Operand;
-  case DenormalMode::PreserveSign:
-    if (APF.isDenormal()) {
-      return ConstantFP::get(
-          Ty->getContext(),
-          APFloat::getZero(Ty->getFltSemantics(), APF.isNegative()));
+
+  Type *Ty = Operand->getType();
+  VectorType *VecTy = dyn_cast<VectorType>(Ty);
+  if (VecTy) {
+    if (auto *Splat = dyn_cast_or_null<ConstantFP>(Operand->getSplatValue())) {
+      ConstantFP *Folded = flushDenormalConstantFP(Splat, Inst, IsOutput);
+      if (!Folded)
+        return nullptr;
+      return ConstantVector::getSplat(VecTy->getElementCount(), Folded);
     }
-    return Operand;
-  case DenormalMode::PositiveZero:
-    if (APF.isDenormal()) {
-      return ConstantFP::get(Ty->getContext(),
-                             APFloat::getZero(Ty->getFltSemantics(), false));
+
+    Ty = VecTy->getElementType();
+  }
+
+  if (const auto *CV = dyn_cast<ConstantVector>(Operand)) {
+    SmallVector<Constant *, 16> NewElts;
+    for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) {
+      Constant *Element = CV->getAggregateElement(i);
+      if (isa<UndefValue>(Element)) {
+        NewElts.push_back(Element);
+        continue;
+      }
+
+      ConstantFP *CFP = dyn_cast<ConstantFP>(Element);
+      if (!CFP)
+        return nullptr;
+
+      ConstantFP *Folded = flushDenormalConstantFP(CFP, Inst, IsOutput);
+      if (!Folded)
+        return nullptr;
+      NewElts.push_back(Folded);
     }
-    return Operand;
+
+    return ConstantVector::get(NewElts);
+  }
+
+  if (const auto *CDV = dyn_cast<ConstantDataVector>(Operand)) {
+    SmallVector<Constant *, 16> NewElts;
+    for (unsigned I = 0, E = CDV->getNumElements(); I < E; ++I) {
+      const APFloat &Elt = CDV->getElementAsAPFloat(I);
+      if (!Elt.isDenormal()) {
+        NewElts.push_back(ConstantFP::get(Ty, Elt));
+      } else {
+        DenormalMode Mode = getInstrDenormalMode(Inst, Ty);
+        ConstantFP *Folded =
+            flushDenormalConstant(Ty, Elt, IsOutput ? Mode.Output : Mode.Input);
+        if (!Folded)
+          return nullptr;
+        NewElts.push_back(Folded);
+      }
+    }
+
+    return ConstantVector::get(NewElts);
   }
-  return Operand;
+
+  return nullptr;
----------------
nikic wrote:

Side-note: We should probably have a generic method on Constant that performs an element-wise operation on it, so we don't have to repeat this scalar/splat/vector code everywhere.

https://github.com/llvm/llvm-project/pull/115407


More information about the llvm-commits mailing list