[llvm] [AMDGPU] Implement IR expansion for frem instruction (PR #130988)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Sun Apr 13 00:31:37 PDT 2025
================
@@ -37,6 +49,342 @@ static cl::opt<unsigned>
cl::desc("fp convert instructions on integers with "
"more than <N> bits are expanded."));
+namespace {
+/// This class implements a precise expansion of the frem instruction.
+/// The generated code is based on the fmod implementation in the AMD device
+/// libs.
+///
+/// For |x| > |y| the remainder is computed by normalizing |x| and |y| with
+/// frexp/ldexp and removing \p Bits quotient bits per loop iteration; the
+/// case |x| <= |y| and the NaN-producing input corner cases are handled
+/// separately.
+class FRemExpander {
+  /// The IRBuilder to use for the expansion.
+  IRBuilder<> &B;
+
+  /// Floating point type of the return value and the arguments of the FRem
+  /// instructions that should be expanded.
+  Type *FremTy;
+
+  /// Floating point type to use for the computation. This may be
+  /// wider than the \p FremTy.
+  Type *ComputeFpTy;
+
+  /// Integer type used to hold the exponents returned by frexp.
+  Type *ExTy;
+
+  /// How many bits of the quotient to compute per iteration of the
+  /// algorithm, stored as a value of type \p ExTy.
+  Value *Bits;
+
+  /// Constant 1 of type \p ExTy.
+  Value *One;
+
+public:
+  /// Create an expander for frem instructions of type \p Ty. Return
+  /// std::nullopt if \p Ty is not supported (non-IEEE-like types, bfloat,
+  /// and fp128 are currently rejected).
+  static std::optional<FRemExpander> create(IRBuilder<> &B, Type *Ty) {
+    // TODO The expansion should work for other types as well, but
+    // this would require additional testing.
+    if (!Ty->isIEEELikeFPTy() || Ty->isBFloatTy() || Ty->isFP128Ty())
+      return std::nullopt;
+
+    // The type to use for the computation of the remainder. This may be
+    // wider than the input/result type.
+    Type *ComputeTy = Ty;
+    // Divisor used to derive the number of quotient bits computed per loop
+    // iteration (Bits = Precision / MaxIter). This value is for the case in
+    // which the computation uses the input/result type itself.
+    unsigned MaxIter = 2;
+
+    if (Ty->is16bitFPTy()) {
+      // Use the wider type and fewer iterations.
+      ComputeTy = B.getFloatTy();
+      MaxIter = 1;
+    }
+
+    unsigned Precision =
+        llvm::APFloat::semanticsPrecision(Ty->getFltSemantics());
+    return FRemExpander{B, Ty, Precision / MaxIter, ComputeTy};
+  }
+
+  /// Build the FRem expansion for the numerator \p X and the
+  /// denominator \p Y using the builder \p B. The type of X and Y
+  /// must match the type for which the class instance has been
+  /// created. \p SQ, if present, is used to prove that \p X is never
+  /// infinite. The code will be generated at the insertion point of \p
+  /// B and the insertion point will be reset at exit. Returns the value
+  /// that replaces the result of the frem.
+  Value *buildFRem(Value *X, Value *Y, std::optional<SimplifyQuery> &SQ) const;
+
+private:
+  FRemExpander(IRBuilder<> &B, Type *FremTy, unsigned NumBits,
+               Type *ComputeFpTy)
+      : B(B), FremTy(FremTy), ComputeFpTy(ComputeFpTy), ExTy(B.getInt32Ty()),
+        Bits(ConstantInt::get(ExTy, NumBits)),
+        One(ConstantInt::get(ExTy, 1)) {}
+
+  /// Build 1.0 / \p V in \p ComputeFpTy.
+  Value *createRcp(Value *V, const Twine &Name) const {
+    // Leave it to later optimizations to turn this into an rcp
+    // instruction if available.
+    return B.CreateFDiv(ConstantFP::get(ComputeFpTy, 1.0), V, Name);
+  }
+
+  /// Helper function to build the UPDATE_AX code which is common to the
+  /// loop body and the "final iteration". Returns the new value of ax.
+  Value *buildUpdateAx(Value *Ax, Value *Ay, Value *Ayinv) const {
+    // Build:
+    //   float q = rint(ax * ayinv);
+    //   ax = fma(-q, ay, ax);
+    //   int clt = ax < 0.0f;
+    //   float axp = ax + ay;
+    //   ax = clt ? axp : ax;
+    Value *Q = B.CreateUnaryIntrinsic(Intrinsic::rint, B.CreateFMul(Ax, Ayinv),
+                                      {}, "q");
+    Value *AxUpdate = B.CreateFMA(B.CreateFNeg(Q), Ay, Ax, {}, "ax");
+    Value *Clt = B.CreateFCmp(CmpInst::FCMP_OLT, AxUpdate,
+                              ConstantFP::get(ComputeFpTy, 0.0), "clt");
+    Value *Axp = B.CreateFAdd(AxUpdate, Ay, "axp");
+    AxUpdate = B.CreateSelect(Clt, Axp, AxUpdate, "ax");
+
+    return AxUpdate;
+  }
+
+  /// Build code to split \p Src into mantissa and exponent with frexp.
+  /// Return a pair whose first element is the mantissa scaled by
+  /// 2^\p NewExp (via ldexp, named \p PowName) and whose second element is
+  /// the frexp exponent minus one (named \p ExName), for use as a loop bound.
+  std::pair<Value *, Value *> buildExpAndPower(Value *Src, Value *NewExp,
+                                               const Twine &ExName,
+                                               const Twine &PowName) const {
+    // Build:
+    //   ExName = frexp_exp(Src) - 1;
+    //   PowName = fldexp(frexp_mant(Src), NewExp);
+    Type *Ty = Src->getType();
+    Value *Frexp = B.CreateIntrinsic(Intrinsic::frexp, {Ty, ExTy}, Src);
+    Value *Mant = B.CreateExtractValue(Frexp, {0});
+    Value *Exp = B.CreateExtractValue(Frexp, {1});
+
+    Exp = B.CreateSub(Exp, One, ExName);
+    Value *Pow = B.CreateLdexp(Mant, NewExp, {}, PowName);
+
+    return {Pow, Exp};
+  }
+
+  /// Build the main computation of the remainder for the case in which
+  /// Ax > Ay, where Ax = |X|, Ay = |Y|, and X is the numerator and Y the
+  /// denominator. \p AxInitial and \p AyInitial are |X| and |Y| in
+  /// \p ComputeFpTy. \p FMF is used as the builder's default fast-math
+  /// flags while building. Add the incoming edge from the computation result
+  /// to \p RetPhi. On return the builder is positioned at the end of the
+  /// loop-exit block, which has no terminator yet; the caller must add one.
+  void buildRemainderComputation(Value *AxInitial, Value *AyInitial, Value *X,
+                                 PHINode *RetPhi, FastMathFlags FMF) const {
+    IRBuilder<>::FastMathFlagGuard Guard(B);
+    B.setFastMathFlags(FMF);
+
+    // Build:
+    //   ex = frexp_exp(ax) - 1;
+    //   ax = fldexp(frexp_mant(ax), bits);
+    //   ey = frexp_exp(ay) - 1;
+    //   ay = fldexp(frexp_mant(ay), 1);
+    auto [Ax, Ex] = buildExpAndPower(AxInitial, Bits, "ex", "ax");
+    auto [Ay, Ey] = buildExpAndPower(AyInitial, One, "ey", "ay");
+
+    // Build:
+    //   int nb = ex - ey;
+    //   float ayinv = 1.0/ay;
+    Value *Nb = B.CreateSub(Ex, Ey, "nb");
+    Value *Ayinv = createRcp(Ay, "ayinv");
+
+    // Build: while (nb > bits)
+    BasicBlock *PreheaderBB = B.GetInsertBlock();
+    Function *Fun = PreheaderBB->getParent();
+    auto *LoopBB = BasicBlock::Create(B.getContext(), "frem.loop_body", Fun);
+    auto *ExitBB = BasicBlock::Create(B.getContext(), "frem.loop_exit", Fun);
+
+    B.CreateCondBr(B.CreateICmp(CmpInst::ICMP_SGT, Nb, Bits), LoopBB, ExitBB);
+
+    // Build loop body:
+    //   UPDATE_AX
+    //   ax = fldexp(ax, bits);
+    //   nb -= bits;
+    // One iteration of the loop is factored out. The code shared by
+    // the loop and this "iteration" is denoted by UPDATE_AX.
+    B.SetInsertPoint(LoopBB);
+    auto *NbIv = B.CreatePHI(Nb->getType(), 2, "nb_iv");
+    NbIv->addIncoming(Nb, PreheaderBB);
+
+    auto *AxPhi = B.CreatePHI(ComputeFpTy, 2, "ax_loop_phi");
+    AxPhi->addIncoming(Ax, PreheaderBB);
+
+    Value *AxPhiUpdate = buildUpdateAx(AxPhi, Ay, Ayinv);
+    AxPhiUpdate = B.CreateLdexp(AxPhiUpdate, Bits, {}, "ax_update");
+    AxPhi->addIncoming(AxPhiUpdate, LoopBB);
+    Value *NbUpdate = B.CreateSub(NbIv, Bits, "nb_update");
+    NbIv->addIncoming(NbUpdate, LoopBB);
+
+    // Re-evaluate "nb > bits" on the decremented nb, as the source loop
+    // does. Testing the incoming nb_iv instead gives the same final values
+    // but only after running one extra iteration whose results are thrown
+    // away.
+    B.CreateCondBr(B.CreateICmp(CmpInst::ICMP_SGT, NbUpdate, Bits), LoopBB,
+                   ExitBB);
+
+    // Build final iteration
+    //   ax = fldexp(ax, nb - bits + 1);
+    //   UPDATE_AX
+    B.SetInsertPoint(ExitBB);
+
+    // When coming from the loop, continue with the values produced by its
+    // last iteration.
+    auto *AxPhiExit = B.CreatePHI(ComputeFpTy, 2, "ax_exit_phi");
+    AxPhiExit->addIncoming(Ax, PreheaderBB);
+    AxPhiExit->addIncoming(AxPhiUpdate, LoopBB);
+    auto *NbExitPhi = B.CreatePHI(Nb->getType(), 2, "nb_exit_phi");
+    NbExitPhi->addIncoming(NbUpdate, LoopBB);
+    NbExitPhi->addIncoming(Nb, PreheaderBB);
+
+    Value *AxFinal = B.CreateLdexp(
+        AxPhiExit, B.CreateAdd(B.CreateSub(NbExitPhi, Bits), One), {}, "ax");
+    AxFinal = buildUpdateAx(AxFinal, Ay, Ayinv);
+
+    // Build:
+    //   ax = fldexp(ax, ey);
+    //   ret = copysign(ax,x);
+    AxFinal = B.CreateLdexp(AxFinal, Ey, {}, "ax");
+    if (ComputeFpTy != FremTy)
+      AxFinal = B.CreateFPTrunc(AxFinal, FremTy);
+    Value *Ret = B.CreateCopySign(AxFinal, X);
+
+    RetPhi->addIncoming(Ret, ExitBB);
+  }
+
+  /// Build the else-branch of the conditional in the FRem
+  /// expansion, i.e. the case in which Ax <= Ay, where Ax = |X|, Ay
+  /// = |Y|, and X is the numerator and Y the denominator. Add the
+  /// incoming edge from the result to \p RetPhi. The builder is left at
+  /// the end of the current (unterminated) block.
+  void buildElseBranch(Value *Ax, Value *Ay, Value *X, PHINode *RetPhi) const {
+    // Build:
+    //   ret = ax == ay ? copysign(0.0f, x) : x;
+    Value *ZeroWithXSign = B.CreateCopySign(ConstantFP::get(FremTy, 0.0), X);
+    Value *Ret = B.CreateSelect(B.CreateFCmpOEQ(Ax, Ay), ZeroWithXSign, X);
+
+    RetPhi->addIncoming(Ret, B.GetInsertBlock());
+  }
+
+  /// Return a value that is NaN if one of the corner cases concerning
+  /// the inputs \p X and \p Y is detected, and \p Ret otherwise. The
+  /// corner cases are: y is zero or NaN, or x is infinite. If \p NoInfs is
+  /// set, or \p SQ proves that x is never infinite, x is treated as finite.
+  Value *handleInputCornerCases(Value *Ret, Value *X, Value *Y,
+                                std::optional<SimplifyQuery> &SQ,
+                                bool NoInfs) const {
+    // Build:
+    //   ret = (y == 0.0f || isnan(y)) ? QNAN : ret;
+    //   ret = isfinite(x) ? ret : QNAN;
+    Value *Nan = ConstantFP::getQNaN(FremTy);
+    Ret = B.CreateSelect(B.CreateFCmpUEQ(Y, ConstantFP::get(FremTy, 0.0)), Nan,
+                         Ret);
+    // Note: "fabs(x) ult inf" also holds for a NaN x, so a NaN x does not
+    // select QNAN here; Ret is passed through unchanged in that case.
+    Value *XFinite =
+        NoInfs || (SQ && isKnownNeverInfinity(X, 0, *SQ))
+            ? B.getTrue()
+            : B.CreateFCmpULT(B.CreateUnaryIntrinsic(Intrinsic::fabs, X),
+                              ConstantFP::getInfinity(FremTy));
+    Ret = B.CreateSelect(XFinite, Ret, Nan);
+
+    return Ret;
+  }
+};
+
+Value *FRemExpander::buildFRem(Value *X, Value *Y,
+                               std::optional<SimplifyQuery> &SQ) const {
+  assert(X->getType() == FremTy && Y->getType() == FremTy);
+
+  FastMathFlags FMF = B.getFastMathFlags();
+
+  // This function generates the following code structure:
+  //   if (abs(x) > abs(y))
+  //   { ret = compute remainder }
+  //   else
+  //   { ret = x or 0 with sign of x }
+  //   Adjust ret to NaN/inf in input
+  //   return ret
+  Value *Ax = B.CreateUnaryIntrinsic(Intrinsic::fabs, X, {}, "ax");
+  Value *Ay = B.CreateUnaryIntrinsic(Intrinsic::fabs, Y, {}, "ay");
+  if (ComputeFpTy != X->getType()) {
+    Ax = B.CreateFPExt(Ax, ComputeFpTy, "ax");
+    Ay = B.CreateFPExt(Ay, ComputeFpTy, "ay");
+  }
+  Value *AxAyCmp = B.CreateFCmpOGT(Ax, Ay);
+
+  // The PHI is temporarily created in the middle of the current block.
+  // Splitting the block right before it (below) makes it the first
+  // instruction of the join block.
+  PHINode *RetPhi = B.CreatePHI(FremTy, 2, "ret");
+  Value *Ret = RetPhi;
+
+  // We would return NaN in all corner cases handled here.
+  // Hence, if NaNs are excluded, keep the result as it is.
+  if (!FMF.noNaNs())
+    Ret = handleInputCornerCases(Ret, X, Y, SQ, FMF.noInfs());
+
+  Function *Fun = B.GetInsertBlock()->getParent();
+  auto *ThenBB = BasicBlock::Create(B.getContext(), "frem.compute", Fun);
+  auto *ElseBB = BasicBlock::Create(B.getContext(), "frem.else", Fun);
+  SplitBlockAndInsertIfThenElse(AxAyCmp, RetPhi, &ThenBB, &ElseBB);
+
+  // Saved so the caller's insertion point can be restored at the end.
+  auto SavedInsertPt = B.GetInsertPoint();
+
+  // Build remainder computation for "then" branch
+  //
+  // The ordered comparison ensures that ax and ay are not NaNs
+  // in the then-branch. Furthermore, y cannot be an infinity and the
+  // check at the end of the function ensures that the result will not
+  // be used if x is an infinity. Hence nnan/ninf may be assumed there.
+  FastMathFlags ComputeFMF = FMF;
+  ComputeFMF.setNoInfs();
+  ComputeFMF.setNoNaNs();
+
+  B.SetInsertPoint(ThenBB);
+  buildRemainderComputation(Ax, Ay, X, RetPhi, ComputeFMF);
+  // The remainder computation leaves the builder at the end of its
+  // (unterminated) exit block.
+  B.CreateBr(RetPhi->getParent());
+
+  // Build "else"-branch
+  B.SetInsertPoint(ElseBB);
+  buildElseBranch(Ax, Ay, X, RetPhi);
+  B.CreateBr(RetPhi->getParent());
+
+  // Setting the insertion point from the iterator also updates the
+  // insertion block, which may have changed due to the split.
+  B.SetInsertPoint(SavedInsertPt);
+
+  return Ret;
+}
+} // namespace
+
+static bool expandFRem(BinaryOperator &I, std::optional<SimplifyQuery> &SQ) {
----------------
arsenm wrote:
Eventually should handle strictfp too. I assume this bakes in the rounding mode somewhere
https://github.com/llvm/llvm-project/pull/130988
More information about the llvm-commits
mailing list