[llvm] ConstantFolding: Do not fold fcmp of denormal without known mode (PR #115407)
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 8 01:26:31 PST 2024
================
@@ -1298,47 +1300,110 @@ Constant *llvm::ConstantFoldBinaryOpOperands(unsigned Opcode, Constant *LHS,
return ConstantFoldBinaryInstruction(Opcode, LHS, RHS);
}
-Constant *llvm::FlushFPConstant(Constant *Operand, const Instruction *I,
- bool IsOutput) {
- if (!I || !I->getParent() || !I->getFunction())
- return Operand;
+/// Return the constant that \p APF of type \p Ty becomes under the denormal
+/// handling kind \p Mode. \p APF is expected to be a denormal; callers check
+/// this before calling. IEEE keeps the value unchanged, PreserveSign yields a
+/// zero with the sign of \p APF, and PositiveZero yields +0.0. Returns nullptr
+/// for Dynamic, where the mode is not known and no result can be assumed. Any
+/// other kind is treated as unreachable.
+static ConstantFP *flushDenormalConstant(Type *Ty, const APFloat &APF,
+ DenormalMode::DenormalModeKind Mode) {
+ switch (Mode) {
+ case DenormalMode::Dynamic:
+ // Flushing behavior is not known here, so refuse to produce a value.
+ return nullptr;
+ case DenormalMode::IEEE:
+ return ConstantFP::get(Ty->getContext(), APF);
+ case DenormalMode::PreserveSign:
+ return ConstantFP::get(
+ Ty->getContext(),
+ APFloat::getZero(APF.getSemantics(), APF.isNegative()));
+ case DenormalMode::PositiveZero:
+ return ConstantFP::get(Ty->getContext(),
+ APFloat::getZero(APF.getSemantics(), false));
+ default:
+ break;
+ }
- ConstantFP *CFP = dyn_cast<ConstantFP>(Operand);
- if (!CFP)
- return Operand;
+ llvm_unreachable("unknown denormal mode");
+}
+
+/// Return the denormal mode that can be assumed when executing a floating point
+/// operation of type \p Ty at \p CtxI. This is the enclosing function's mode for
+/// the float semantics of \p Ty. If \p CtxI is null or not inserted into a
+/// function, the dynamic (unknown) mode is returned instead.
+static DenormalMode getInstrDenormalMode(const Instruction *CtxI, Type *Ty) {
+ if (!CtxI || !CtxI->getParent() || !CtxI->getFunction())
+ return DenormalMode::getDynamic();
+ return CtxI->getFunction()->getDenormalMode(Ty->getFltSemantics());
+}
+/// Flush \p CFP according to the denormal mode in effect at \p Inst. The
+/// output mode is used when \p IsOutput is set, and the input mode otherwise.
+/// Non-denormal values are returned unchanged. Returns nullptr if the
+/// applicable mode is Dynamic (e.g. when \p Inst provides no function context).
+static ConstantFP *flushDenormalConstantFP(ConstantFP *CFP,
+ const Instruction *Inst,
+ bool IsOutput) {
const APFloat &APF = CFP->getValueAPF();
- // TODO: Should this canonicalize nans?
if (!APF.isDenormal())
- return Operand;
+ return CFP;
- Type *Ty = CFP->getType();
- DenormalMode DenormMode =
- I->getFunction()->getDenormalMode(Ty->getFltSemantics());
- DenormalMode::DenormalModeKind Mode =
- IsOutput ? DenormMode.Output : DenormMode.Input;
- switch (Mode) {
- default:
- llvm_unreachable("unknown denormal mode");
- case DenormalMode::Dynamic:
- return nullptr;
- case DenormalMode::IEEE:
+ // Only denormals depend on the mode, so it is looked up only on this path.
+ DenormalMode Mode = getInstrDenormalMode(Inst, CFP->getType());
+ return flushDenormalConstant(CFP->getType(), APF,
+ IsOutput ? Mode.Output : Mode.Input);
+}
+
+Constant *llvm::FlushFPConstant(Constant *Operand, const Instruction *Inst,
+ bool IsOutput) {
+ if (ConstantFP *CFP = dyn_cast<ConstantFP>(Operand))
+ return flushDenormalConstantFP(CFP, Inst, IsOutput);
+
+ if (isa<ConstantAggregateZero, UndefValue>(Operand))
return Operand;
- case DenormalMode::PreserveSign:
- if (APF.isDenormal()) {
- return ConstantFP::get(
- Ty->getContext(),
- APFloat::getZero(Ty->getFltSemantics(), APF.isNegative()));
+
+ Type *Ty = Operand->getType();
+ VectorType *VecTy = dyn_cast<VectorType>(Ty);
+ if (VecTy) {
+ if (auto *Splat = dyn_cast_or_null<ConstantFP>(Operand->getSplatValue())) {
+ ConstantFP *Folded = flushDenormalConstantFP(Splat, Inst, IsOutput);
+ if (!Folded)
+ return nullptr;
+ return ConstantVector::getSplat(VecTy->getElementCount(), Folded);
}
- return Operand;
- case DenormalMode::PositiveZero:
- if (APF.isDenormal()) {
- return ConstantFP::get(Ty->getContext(),
- APFloat::getZero(Ty->getFltSemantics(), false));
+
+ Ty = VecTy->getElementType();
+ }
+
+ if (const auto *CV = dyn_cast<ConstantVector>(Operand)) {
+ SmallVector<Constant *, 16> NewElts;
+ for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) {
+ Constant *Element = CV->getAggregateElement(i);
+ if (isa<UndefValue>(Element)) {
+ NewElts.push_back(Element);
+ continue;
+ }
+
+ ConstantFP *CFP = dyn_cast<ConstantFP>(Element);
+ if (!CFP)
+ return nullptr;
+
+ ConstantFP *Folded = flushDenormalConstantFP(CFP, Inst, IsOutput);
+ if (!Folded)
+ return nullptr;
+ NewElts.push_back(Folded);
}
- return Operand;
+
+ return ConstantVector::get(NewElts);
+ }
+
+ if (const auto *CDV = dyn_cast<ConstantDataVector>(Operand)) {
+ SmallVector<Constant *, 16> NewElts;
+ for (unsigned I = 0, E = CDV->getNumElements(); I < E; ++I) {
+ const APFloat &Elt = CDV->getElementAsAPFloat(I);
+ if (!Elt.isDenormal()) {
+ NewElts.push_back(ConstantFP::get(Ty, Elt));
+ } else {
+ DenormalMode Mode = getInstrDenormalMode(Inst, Ty);
+ ConstantFP *Folded =
+ flushDenormalConstant(Ty, Elt, IsOutput ? Mode.Output : Mode.Input);
+ if (!Folded)
+ return nullptr;
+ NewElts.push_back(Folded);
+ }
+ }
+
+ return ConstantVector::get(NewElts);
}
- return Operand;
+
+ return nullptr;
----------------
nikic wrote:
Side-note: We should probably have a generic method on Constant that applies an element-wise operation to it, so we don't have to repeat this scalar/splat/vector code everywhere.
https://github.com/llvm/llvm-project/pull/115407
More information about the llvm-commits mailing list