[llvm] [InstCombine] Add folds for `(fp_binop ({s|u}itofp x), ({s|u}itofp y))` (PR #82555)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 22 08:24:12 PST 2024


================
@@ -1401,6 +1401,171 @@ Value *InstCombinerImpl::dyn_castNegVal(Value *V) const {
   return nullptr;
 }
 
+// Try to fold:
+//    1) (fp_binop ({s|u}itofp x), ({s|u}itofp y))
+//        -> ({s|u}itofp (int_binop x, y))
+//    2) (fp_binop ({s|u}itofp x), FpC)
+//        -> ({s|u}itofp (int_binop x, (fpto{s|u}i FpC)))
+Instruction *InstCombinerImpl::foldFBinOpOfIntCasts(BinaryOperator &BO) {
+  Value *IntOps[2];
+  Constant *Op1FpC = nullptr;
+
+  // Check for:
+  //    1) (binop ({s|u}itofp x), ({s|u}itofp y))
+  //    2) (binop ({s|u}itofp x), FpC)
+  if (!match(BO.getOperand(0), m_SIToFP(m_Value(IntOps[0]))) &&
+      !match(BO.getOperand(0), m_UIToFP(m_Value(IntOps[0]))))
+    return nullptr;
+
+  if (!match(BO.getOperand(1), m_Constant(Op1FpC)) &&
+      !match(BO.getOperand(1), m_SIToFP(m_Value(IntOps[1]))) &&
+      !match(BO.getOperand(1), m_UIToFP(m_Value(IntOps[1]))))
+    return nullptr;
+
+
+  Type *FPTy = BO.getType();
+  Type *IntTy = IntOps[0]->getType();
+
+  // Do we have signed casts?
+  bool OpsFromSigned = isa<SIToFPInst>(BO.getOperand(0));
+
+
+  unsigned IntSz = IntTy->getScalarSizeInBits();
+  // This is the maximum number of inuse bits by the integer where the int -> fp
+  // casts are exact.
+  unsigned MaxRepresentableBits =
+      APFloat::semanticsPrecision(FPTy->getScalarType()->getFltSemantics());
+
+  // Cache KnownBits a bit to potentially save some analysis.
+  std::optional<KnownBits> OpsKnown[2];
+
+  // Preserve known number of leading bits. This can allow us to trivial nsw/nuw
+  // checks later on.
+  unsigned NumUsedLeadingBits[2] = {IntSz, IntSz};
+
+  auto IsNonZero = [&](unsigned OpNo) -> bool {
+    if (OpsKnown[OpNo].has_value() && OpsKnown[OpNo]->isNonZero())
+      return true;
+    return isKnownNonZero(IntOps[OpNo], SQ.DL);
+  };
+
+  auto IsNonNeg = [&](unsigned OpNo) -> bool {
+    if (OpsKnown[OpNo].has_value() && OpsKnown[OpNo]->isNonNegative())
+      return true;
+    return isKnownNonNegative(IntOps[OpNo], SQ);
+  };
+
+  // Check if we know for certain that ({s|u}itofp op) is exact.
+  auto IsValidPromotion = [&](unsigned OpNo) -> bool {
+    // If fp precision >= bitwidth(op) then its exact.
+    if (MaxRepresentableBits >= IntSz)
+      ;
+    // Otherwise if its signed cast check that fp precisions >= bitwidth(op) -
+    // numSignBits(op).
+    else if (OpsFromSigned)
+      NumUsedLeadingBits[OpNo] = IntSz - ComputeNumSignBits(IntOps[OpNo]);
+    // Finally for unsigned check that fp precision >= bitwidth(op) -
+    // numLeadingZeros(op).
+    else {
+      if (!OpsKnown[OpNo].has_value())
+        OpsKnown[OpNo] = computeKnownBits(IntOps[OpNo], /*Depth*/ 0, &BO);
+      NumUsedLeadingBits[OpNo] = IntSz - OpsKnown[OpNo]->countMinLeadingZeros();
+    }
+    // NB: We could also check if op is known to be a power of 2 or zero (which
+    // will always be representable). Its unlikely, however, that is we are
+    // unable to bound op in any way we will be able to pass the overflow checks
+    // later on.
+
+    if (MaxRepresentableBits < NumUsedLeadingBits[OpNo])
+      return false;
+    // Signed + Mul also requires that op is non-zero to avoid -0 cases.
+    return (OpsFromSigned && BO.getOpcode() == Instruction::FMul)
+               ? IsNonZero(OpNo)
+               : true;
+
+  };
+
+  // If we have a constant rhs, see if we can losslessly convert it to an int.
+  if (Op1FpC != nullptr) {
+    Constant *Op1IntC = ConstantFoldCastOperand(
+        OpsFromSigned ? Instruction::FPToSI : Instruction::FPToUI, Op1FpC,
+        IntTy, DL);
+    if (Op1IntC == nullptr)
+      return nullptr;
+    if (ConstantFoldCastOperand(OpsFromSigned ? Instruction::SIToFP
+                                              : Instruction::UIToFP,
+                                Op1IntC, FPTy, DL) != Op1FpC)
+      return nullptr;
+
+    // First try to keep sign of cast the same.
+    IntOps[1] = Op1IntC;
+  }
+
+  // Ensure lhs/rhs integer types match.
+  if (IntTy != IntOps[1]->getType())
+    return nullptr;
+
+
+  if (Op1FpC == nullptr) {
+    if (OpsFromSigned != isa<SIToFPInst>(BO.getOperand(1))) {
+      // If we have a signed + unsigned, see if we can treat both as signed
+      // (uitofp nneg x) == (sitofp nneg x).
+      if (OpsFromSigned ? !IsNonNeg(1) : !IsNonNeg(0))
+        return nullptr;
+      OpsFromSigned = true;
+    }
+    if (!IsValidPromotion(1))
+      return nullptr;
+  }
+  if (!IsValidPromotion(0))
+    return nullptr;
+
+  // Final we check if the integer version of the binop will not overflow.
+  BinaryOperator::BinaryOps IntOpc;
+  // Because of the precision check, we can often rule out overflows.
+  bool NeedsOverflowCheck = true;
+  // Try to conservatively rule out overflow based on the already done precision
+  // checks.
+  unsigned OverflowMaxOutputBits = OpsFromSigned ? 2 : 1;
+  unsigned OverflowMaxCurBits =
+      std::max(NumUsedLeadingBits[0], NumUsedLeadingBits[1]);
+  bool OutputSigned = OpsFromSigned;
+  switch (BO.getOpcode()) {
+  case Instruction::FAdd:
+    IntOpc = Instruction::Add;
+    OverflowMaxOutputBits += OverflowMaxCurBits;
+    break;
+  case Instruction::FSub:
+    IntOpc = Instruction::Sub;
+    OverflowMaxOutputBits += OverflowMaxCurBits;
+    break;
+  case Instruction::FMul:
+    IntOpc = Instruction::Mul;
+    OverflowMaxOutputBits += OverflowMaxCurBits * 2;
+    break;
+  default:
+    llvm_unreachable("Unsupported binop");
+  }
+  // The precision check may have already ruled out overflow.
+  if (OverflowMaxOutputBits < IntSz) {
+    NeedsOverflowCheck = false;
+    // We can bound unsigned overflow from sub to in range signed value (this is
+    // what allows us to avoid the overflow check for sub).
+    if (IntOpc == Instruction::Sub)
+      OutputSigned = true;
+  }
+
+  // Precision check did not rule out overflow, so need to check.
+  if (NeedsOverflowCheck &&
+      !willNotOverflow(IntOpc, IntOps[0], IntOps[1], BO, OutputSigned))
+    return nullptr;
+
+  Value *IntBinOp = Builder.CreateBinOp(IntOpc, IntOps[0], IntOps[1]);
----------------
goldsteinn wrote:

Sure, I've added the flag we checked for (the one proven by the overflow check) by default. We could also infer the other flag in some cases, but I'd rather let the existing logic handle that. We don't lose any `nsw`/`nuw` information during this transform.

https://github.com/llvm/llvm-project/pull/82555


More information about the llvm-commits mailing list