[llvm] [InstCombine] Add folds for `(fp_binop ({s|u}itofp x), ({s|u}itofp y))` (PR #82555)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 22 08:19:59 PST 2024


================
@@ -1401,6 +1401,171 @@ Value *InstCombinerImpl::dyn_castNegVal(Value *V) const {
   return nullptr;
 }
 
+// Try to fold:
+//    1) (fp_binop ({s|u}itofp x), ({s|u}itofp y))
+//        -> ({s|u}itofp (int_binop x, y))
+//    2) (fp_binop ({s|u}itofp x), FpC)
+//        -> ({s|u}itofp (int_binop x, (fpto{s|u}i FpC)))
+Instruction *InstCombinerImpl::foldFBinOpOfIntCasts(BinaryOperator &BO) {
+  Value *IntOps[2];
+  Constant *Op1FpC = nullptr;
+
+  // Check for:
+  //    1) (binop ({s|u}itofp x), ({s|u}itofp y))
+  //    2) (binop ({s|u}itofp x), FpC)
+  if (!match(BO.getOperand(0), m_SIToFP(m_Value(IntOps[0]))) &&
+      !match(BO.getOperand(0), m_UIToFP(m_Value(IntOps[0]))))
+    return nullptr;
+
+  if (!match(BO.getOperand(1), m_Constant(Op1FpC)) &&
+      !match(BO.getOperand(1), m_SIToFP(m_Value(IntOps[1]))) &&
+      !match(BO.getOperand(1), m_UIToFP(m_Value(IntOps[1]))))
+    return nullptr;
+
+
+  Type *FPTy = BO.getType();
+  Type *IntTy = IntOps[0]->getType();
+
+  // Do we have signed casts?
+  bool OpsFromSigned = isa<SIToFPInst>(BO.getOperand(0));
+
+
+  unsigned IntSz = IntTy->getScalarSizeInBits();
+  // This is the maximum number of in-use bits of the integer for which the
+  // int -> fp casts are exact.
+  unsigned MaxRepresentableBits =
+      APFloat::semanticsPrecision(FPTy->getScalarType()->getFltSemantics());
+
+  // Cache KnownBits a bit to potentially save some analysis.
+  std::optional<KnownBits> OpsKnown[2];
+
+  // Preserve known number of leading bits. This can allow us to make the
+  // nsw/nuw checks trivial later on.
+  unsigned NumUsedLeadingBits[2] = {IntSz, IntSz};
+
+  auto IsNonZero = [&](unsigned OpNo) -> bool {
+    if (OpsKnown[OpNo].has_value() && OpsKnown[OpNo]->isNonZero())
+      return true;
+    return isKnownNonZero(IntOps[OpNo], SQ.DL);
+  };
+
+  auto IsNonNeg = [&](unsigned OpNo) -> bool {
+    if (OpsKnown[OpNo].has_value() && OpsKnown[OpNo]->isNonNegative())
+      return true;
+    return isKnownNonNegative(IntOps[OpNo], SQ);
+  };
+
+  // Check if we know for certain that ({s|u}itofp op) is exact.
+  auto IsValidPromotion = [&](unsigned OpNo) -> bool {
+    // If fp precision >= bitwidth(op) then it's exact.
+    if (MaxRepresentableBits >= IntSz)
+      ;
+    // Otherwise, if it's a signed cast, check that fp precision >=
+    // bitwidth(op) - numSignBits(op).
+    else if (OpsFromSigned)
+      NumUsedLeadingBits[OpNo] = IntSz - ComputeNumSignBits(IntOps[OpNo]);
+    // Finally for unsigned check that fp precision >= bitwidth(op) -
+    // numLeadingZeros(op).
+    else {
+      if (!OpsKnown[OpNo].has_value())
+        OpsKnown[OpNo] = computeKnownBits(IntOps[OpNo], /*Depth*/ 0, &BO);
+      NumUsedLeadingBits[OpNo] = IntSz - OpsKnown[OpNo]->countMinLeadingZeros();
+    }
+    // NB: We could also check if op is known to be a power of 2 or zero (which
+    // will always be representable). It's unlikely, however, that if we are
+    // unable to bound op in any way we will be able to pass the overflow checks
+    // later on.
+
+    if (MaxRepresentableBits < NumUsedLeadingBits[OpNo])
+      return false;
+    // Signed + Mul also requires that op is non-zero to avoid -0 cases.
+    return (OpsFromSigned && BO.getOpcode() == Instruction::FMul)
+               ? IsNonZero(OpNo)
+               : true;
+
+  };
+
+  // If we have a constant rhs, see if we can losslessly convert it to an int.
+  if (Op1FpC != nullptr) {
+    Constant *Op1IntC = ConstantFoldCastOperand(
+        OpsFromSigned ? Instruction::FPToSI : Instruction::FPToUI, Op1FpC,
+        IntTy, DL);
+    if (Op1IntC == nullptr)
+      return nullptr;
+    if (ConstantFoldCastOperand(OpsFromSigned ? Instruction::SIToFP
+                                              : Instruction::UIToFP,
+                                Op1IntC, FPTy, DL) != Op1FpC)
+      return nullptr;
+
+    // First try to keep sign of cast the same.
+    IntOps[1] = Op1IntC;
+  }
+
+  // Ensure lhs/rhs integer types match.
+  if (IntTy != IntOps[1]->getType())
+    return nullptr;
+
+
+  if (Op1FpC == nullptr) {
+    if (OpsFromSigned != isa<SIToFPInst>(BO.getOperand(1))) {
+      // If we have a signed + unsigned, see if we can treat both as signed
+      // (uitofp nneg x) == (sitofp nneg x).
+      if (OpsFromSigned ? !IsNonNeg(1) : !IsNonNeg(0))
----------------
goldsteinn wrote:

It's definitely more succinct, but I also think it's less clear. Are you okay with me keeping it as is?

https://github.com/llvm/llvm-project/pull/82555


More information about the llvm-commits mailing list