[llvm] [AMDGPU] Implement IR expansion for frem instruction (PR #130988)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Sun Apr 13 00:31:37 PDT 2025


================
@@ -37,6 +49,342 @@ static cl::opt<unsigned>
                         cl::desc("fp convert instructions on integers with "
                                  "more than <N> bits are expanded."));
 
+namespace {
+/// This class implements a precise expansion of the frem instruction.
+/// The generated code is based on the fmod implementation in the AMD device
+/// libs.
+class FRemExpander {
+  /// The IRBuilder to use for the expansion.
+  IRBuilder<> &B;
+
+  /// Floating point type of the return value and the arguments of the FRem
+  /// instructions that should be expanded.
+  Type *FremTy;
+
+  /// Floating point type to use for the computation.  This may be
+  /// wider than the \p FremTy.
+  Type *ComputeFpTy;
+
+  /// Integer type used to hold the exponents returned by frexp.
+  Type *ExTy;
+
+  /// How many bits of the quotient to compute per iteration of the
+  /// algorithm, stored as a value of type \p ExTy.
+  Value *Bits;
+
+  /// Constant 1 of type \p ExTy.
+  Value *One;
+
+public:
+  /// Create an expander for frem instructions of floating point type
+  /// \p Ty that emits its code through \p B.  Returns std::nullopt if
+  /// \p Ty is not an IEEE-like type or is bfloat or fp128.
+  static std::optional<FRemExpander> create(IRBuilder<> &B, Type *Ty) {
+    // TODO The expansion should work for other types as well, but
+    // this would require additional testing.
+    if (!Ty->isIEEELikeFPTy() || Ty->isBFloatTy() || Ty->isFP128Ty())
+      return std::nullopt;
+
+    // The type to use for the computation of the remainder. This may be
+    // wider than the input/result type which affects the ...
+    Type *ComputeTy = Ty;
+    // ... maximum number of iterations of the remainder computation loop
+    // to use. This value is for the case in which the computation
+    // uses the same input/result type.
+    unsigned MaxIter = 2;
+
+    if (Ty->is16bitFPTy()) {
+      // Use the wider type and fewer iterations.
+      ComputeTy = B.getFloatTy();
+      MaxIter = 1;
+    }
+
+    // The number of quotient bits per iteration is derived from the
+    // precision of the input/result type, not of the computation type.
+    unsigned Precision =
+        llvm::APFloat::semanticsPrecision(Ty->getFltSemantics());
+    return FRemExpander{B, Ty, Precision / MaxIter, ComputeTy};
+  }
+
+  /// Build the FRem expansion for the numerator \p X and the
+  /// denominator \p Y using the builder \p B.  The type of X and Y
+  /// must match the type for which the class instance has been
+  /// created. The code will be generated at the insertion point of \p
+  /// B and the insertion point will be reset at exit.
+  Value *buildFRem(Value *X, Value *Y, std::optional<SimplifyQuery> &SQ) const;
+
+private:
+  /// \p Bits is the number of quotient bits computed per iteration; it
+  /// is materialized as a constant of type \p ExTy.
+  FRemExpander(IRBuilder<> &B, Type *FremTy, unsigned Bits, Type *ComputeFpTy)
+      : B(B), FremTy(FremTy), ComputeFpTy(ComputeFpTy), ExTy(B.getInt32Ty()),
+        Bits(ConstantInt::get(ExTy, Bits)), One(ConstantInt::get(ExTy, 1)) {}
+
+  /// Build the reciprocal 1.0 / \p V in the computation type.
+  Value *createRcp(Value *V, const Twine &Name) const {
+    // Leave it to later optimizations to turn this into an rcp
+    // instruction if available.
+    return B.CreateFDiv(ConstantFP::get(ComputeFpTy, 1.0), V, Name);
+  }
+
+  // Helper function to build the UPDATE_AX code which is common to the
+  // loop body and the "final iteration".
+  Value *buildUpdateAx(Value *Ax, Value *Ay, Value *Ayinv) const {
+    // Build:
+    //   float q = rint(ax * ayinv);
+    //   ax = fma(-q, ay, ax);
+    //   int clt = ax < 0.0f;
+    //   float axp = ax + ay;
+    //   ax = clt ? axp : ax;
+    Value *Q = B.CreateUnaryIntrinsic(Intrinsic::rint, B.CreateFMul(Ax, Ayinv),
+                                      {}, "q");
+    Value *AxUpdate = B.CreateFMA(B.CreateFNeg(Q), Ay, Ax, {}, "ax");
+    Value *Clt = B.CreateFCmp(CmpInst::FCMP_OLT, AxUpdate,
+                              ConstantFP::get(ComputeFpTy, 0.0), "clt");
+    Value *Axp = B.CreateFAdd(AxUpdate, Ay, "axp");
+    AxUpdate = B.CreateSelect(Clt, Axp, AxUpdate, "ax");
+
+    return AxUpdate;
+  }
+
+  /// Build code to extract the exponent and mantissa of \p Src.
+  /// Return a pair whose first element is the mantissa scaled by
+  /// 2^\p NewExp and whose second element is the exponent minus one
+  /// (for use as a loop bound).
+  std::pair<Value *, Value *> buildExpAndPower(Value *Src, Value *NewExp,
+                                               const Twine &ExName,
+                                               const Twine &PowName) const {
+    // Build:
+    //   ExName = frexp_exp(Src) - 1;
+    //   PowName = fldexp(frexp_mant(Src), NewExp);
+    Type *Ty = Src->getType();
+    // Use the class-wide exponent type so that the frexp result agrees
+    // with the type of the Bits and One constants.
+    Value *Frexp = B.CreateIntrinsic(Intrinsic::frexp, {Ty, ExTy}, Src);
+    Value *Mant = B.CreateExtractValue(Frexp, {0});
+    Value *Exp = B.CreateExtractValue(Frexp, {1});
+
+    Exp = B.CreateSub(Exp, One, ExName);
+    Value *Pow = B.CreateLdexp(Mant, NewExp, {}, PowName);
+
+    return {Pow, Exp};
+  }
+
+  /// Build the main computation of the remainder for the case in which
+  /// Ax > Ay, where Ax = |X|, Ay = |Y|, and X is the numerator and Y the
+  /// denominator. Add the incoming edge from the computation result
+  /// to \p RetPhi.  All floating point instructions are built with the
+  /// fast-math flags \p FMF; the builder's flags are restored on exit.
+  void buildRemainderComputation(Value *AxInitial, Value *AyInitial, Value *X,
+                                 PHINode *RetPhi, FastMathFlags FMF) const {
+    IRBuilder<>::FastMathFlagGuard Guard(B);
+    B.setFastMathFlags(FMF);
+
+    // Build:
+    // ex = frexp_exp(ax) - 1;
+    // ax = fldexp(frexp_mant(ax), bits);
+    // ey = frexp_exp(ay) - 1;
+    // ay = fldexp(frexp_mant(ay), 1);
+    auto [Ax, Ex] = buildExpAndPower(AxInitial, Bits, "ex", "ax");
+    auto [Ay, Ey] = buildExpAndPower(AyInitial, One, "ey", "ay");
+
+    // Build:
+    //   int nb = ex - ey;
+    //   float ayinv = 1.0/ay;
+    Value *Nb = B.CreateSub(Ex, Ey, "nb");
+    Value *Ayinv = createRcp(Ay, "ayinv");
+
+    // Build: while (nb > bits)
+    BasicBlock *PreheaderBB = B.GetInsertBlock();
+    Function *Fun = PreheaderBB->getParent();
+    auto *LoopBB = BasicBlock::Create(B.getContext(), "frem.loop_body", Fun);
+    auto *ExitBB = BasicBlock::Create(B.getContext(), "frem.loop_exit", Fun);
+
+    // The loop is emitted in rotated form: this check guards the first
+    // iteration, the check at the end of the body guards all others.
+    B.CreateCondBr(B.CreateICmp(CmpInst::ICMP_SGT, Nb, Bits), LoopBB, ExitBB);
+
+    // Build loop body:
+    //   UPDATE_AX
+    //   ax = fldexp(ax, bits);
+    //   nb -= bits;
+    // One iteration of the loop is factored out.  The code shared by
+    // the loop and this "iteration" is denoted by UPDATE_AX.
+    B.SetInsertPoint(LoopBB);
+    auto *NbIv = B.CreatePHI(Nb->getType(), 2, "nb_iv");
+    NbIv->addIncoming(Nb, PreheaderBB);
+
+    auto *AxPhi = B.CreatePHI(ComputeFpTy, 2, "ax_loop_phi");
+    AxPhi->addIncoming(Ax, PreheaderBB);
+
+    Value *AxPhiUpdate = buildUpdateAx(AxPhi, Ay, Ayinv);
+    AxPhiUpdate = B.CreateLdexp(AxPhiUpdate, Bits, {}, "ax_update");
+    AxPhi->addIncoming(AxPhiUpdate, LoopBB);
+    Value *NbUpdate = B.CreateSub(NbIv, Bits, "nb_update");
+    NbIv->addIncoming(NbUpdate, LoopBB);
+
+    // Test the decremented counter.  Testing the value from the start
+    // of the iteration would execute the body one additional time and
+    // then discard the result of that iteration.
+    B.CreateCondBr(B.CreateICmp(CmpInst::ICMP_SGT, NbUpdate, Bits), LoopBB,
+                   ExitBB);
+
+    // Build final iteration
+    //   ax = fldexp(ax, nb - bits + 1);
+    //   UPDATE_AX
+    B.SetInsertPoint(ExitBB);
+
+    // On the edge from the loop, the exit block receives the values
+    // produced by the last executed iteration.
+    auto *AxPhiExit = B.CreatePHI(ComputeFpTy, 2, "ax_exit_phi");
+    AxPhiExit->addIncoming(Ax, PreheaderBB);
+    AxPhiExit->addIncoming(AxPhiUpdate, LoopBB);
+    auto *NbExitPhi = B.CreatePHI(Nb->getType(), 2, "nb_exit_phi");
+    NbExitPhi->addIncoming(NbUpdate, LoopBB);
+    NbExitPhi->addIncoming(Nb, PreheaderBB);
+
+    Value *AxFinal = B.CreateLdexp(
+        AxPhiExit, B.CreateAdd(B.CreateSub(NbExitPhi, Bits), One), {}, "ax");
+    AxFinal = buildUpdateAx(AxFinal, Ay, Ayinv);
+
+    // Build:
+    //    ax = fldexp(ax, ey);
+    //    ret = copysign(ax,x);
+    AxFinal = B.CreateLdexp(AxFinal, Ey, {}, "ax");
+    // Convert back to the result type if the computation used a wider one.
+    if (ComputeFpTy != FremTy)
+      AxFinal = B.CreateFPTrunc(AxFinal, FremTy);
+    Value *Ret = B.CreateCopySign(AxFinal, X);
+
+    RetPhi->addIncoming(Ret, ExitBB);
+  }
+
+  /// Build the else-branch of the conditional in the FRem
+  /// expansion, i.e. the case in which Ax <= Ay, where Ax = |X|, Ay
+  /// = |Y|, and X is the numerator and Y the denominator. Add the
+  /// incoming edge from the result to \p RetPhi.
+  void buildElseBranch(Value *Ax, Value *Ay, Value *X, PHINode *RetPhi) const {
+    // Build:
+    // ret = ax == ay ? copysign(0.0f, x) : x;
+    Value *ZeroWithXSign = B.CreateCopySign(ConstantFP::get(FremTy, 0.0), X);
+    Value *Ret = B.CreateSelect(B.CreateFCmpOEQ(Ax, Ay), ZeroWithXSign, X);
+
+    RetPhi->addIncoming(Ret, B.GetInsertBlock());
+  }
+
+  /// Return a value that is NaN if one of the corner cases concerning
+  /// the inputs \p X and \p Y is detected, and \p Ret otherwise.
+  /// The check on \p X is skipped if \p NoInfs is set or if \p SQ is
+  /// available and shows that \p X is never an infinity.
+  Value *handleInputCornerCases(Value *Ret, Value *X, Value *Y,
+                                std::optional<SimplifyQuery> &SQ,
+                                bool NoInfs) const {
+    // Build:
+    //   ret = (y == 0.0f || isnan(y)) ? QNAN : ret;
+    //   ret = isfinite(x) ? ret : QNAN;
+    Value *Nan = ConstantFP::getQNaN(FremTy);
+    // The unordered comparison is also true if y is a NaN.
+    Ret = B.CreateSelect(B.CreateFCmpUEQ(Y, ConstantFP::get(FremTy, 0.0)), Nan,
+                         Ret);
+    Value *XFinite =
+        (NoInfs || (SQ && isKnownNeverInfinity(X, 0, *SQ)))
+            ? B.getTrue()
+            : B.CreateFCmpULT(B.CreateUnaryIntrinsic(Intrinsic::fabs, X),
+                              ConstantFP::getInfinity(FremTy));
+    Ret = B.CreateSelect(XFinite, Ret, Nan);
+
+    return Ret;
+  }
+};
+
+/// Emit the expansion of "frem X, Y" at the current insertion point of the
+/// builder and return the value holding the result.  The block containing
+/// the insertion point is split into an if-then-else; the "then" block
+/// computes the remainder for |X| > |Y| and the "else" block handles
+/// |X| <= |Y|.  Unless the builder's fast-math flags exclude NaNs, the
+/// merged result is additionally adjusted for the input corner cases.
+/// \p SQ, if present, is used to elide the infinity check on \p X.
+Value *FRemExpander::buildFRem(Value *X, Value *Y,
+                               std::optional<SimplifyQuery> &SQ) const {
+  assert(X->getType() == FremTy && Y->getType() == FremTy);
+
+  FastMathFlags FMF = B.getFastMathFlags();
+
+  // This function generates the following code structure:
+  //   if (abs(x) > abs(y))
+  //   { ret = compute remainder }
+  //   else
+  //   { ret = x or 0 with sign of x }
+  //   Adjust ret to NaN/inf in input
+  //   return ret
+  Value *Ax = B.CreateUnaryIntrinsic(Intrinsic::fabs, X, {}, "ax");
+  Value *Ay = B.CreateUnaryIntrinsic(Intrinsic::fabs, Y, {}, "ay");
+  if (ComputeFpTy != X->getType()) {
+    Ax = B.CreateFPExt(Ax, ComputeFpTy, "ax");
+    Ay = B.CreateFPExt(Ay, ComputeFpTy, "ay");
+  }
+  Value *AxAyCmp = B.CreateFCmpOGT(Ax, Ay);
+
+  // The phi merging the results of both branches is created first; the
+  // block is split in front of it below.
+  PHINode *RetPhi = B.CreatePHI(FremTy, 2, "ret");
+  Value *Ret = RetPhi;
+
+  // We would return NaN in all corner cases handled here.
+  // Hence, if NaNs are excluded, keep the result as it is.
+  if (!FMF.noNaNs())
+    Ret = handleInputCornerCases(Ret, X, Y, SQ, FMF.noInfs());
+
+  Function *Fun = B.GetInsertBlock()->getParent();
+  auto *ThenBB = BasicBlock::Create(B.getContext(), "frem.compute", Fun);
+  auto *ElseBB = BasicBlock::Create(B.getContext(), "frem.else", Fun);
+  SplitBlockAndInsertIfThenElse(AxAyCmp, RetPhi, &ThenBB, &ElseBB);
+
+  auto SavedInsertPt = B.GetInsertPoint();
+
+  // Build remainder computation for "then" branch
+  //
+  // The ordered comparison ensures that ax and ay are not NaNs
+  // in the then-branch. Furthermore, y cannot be an infinity and the
+  // check at the end of the function ensures that the result will not
+  // be used if x is an infinity.
+  FastMathFlags ComputeFMF = FMF;
+  ComputeFMF.setNoInfs();
+  ComputeFMF.setNoNaNs();
+
+  // Pass the strengthened flags; passing the unmodified FMF would leave
+  // ComputeFMF unused and drop the flags justified above.
+  B.SetInsertPoint(ThenBB);
+  buildRemainderComputation(Ax, Ay, X, RetPhi, ComputeFMF);
+  B.CreateBr(RetPhi->getParent());
+
+  // Build "else"-branch
+  B.SetInsertPoint(ElseBB);
+  buildElseBranch(Ax, Ay, X, RetPhi);
+  B.CreateBr(RetPhi->getParent());
+
+  B.SetInsertPoint(SavedInsertPt);
+
+  return Ret;
+}
+} // namespace
+
+static bool expandFRem(BinaryOperator &I, std::optional<SimplifyQuery> &SQ) {
----------------
arsenm wrote:

Eventually this should handle strictfp too. I assume this bakes in the rounding mode somewhere.

https://github.com/llvm/llvm-project/pull/130988


More information about the llvm-commits mailing list