[llvm] [AMDGPU] Implement IR expansion for frem instruction (PR #130988)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 28 09:01:24 PDT 2025
================
@@ -37,6 +48,352 @@ static cl::opt<unsigned>
cl::desc("fp convert instructions on integers with "
"more than <N> bits are expanded."));
+namespace {
+/// This class implements a precise expansion of the frem instruction.
+/// The generated code is based on the fmod implementation in the AMD device
+/// libs.
+class FRemExpander {
+ /// The IRBuilder to use for the expansion.
+ IRBuilder<> &B;
+
+ /// Floating point type of the return value and the arguments of the FRem
+ /// instructions that should be expanded.
+ Type *FremTy;
+
+ /// Floating point type to use for the computation. This may be
+ /// wider than the \p FremTy.
+ Type *ComputeFpTy;
+
+ /// Integer type that can hold floating point values of type \p FremTy.
+ Type *IntTy;
+
+ /// Integer type used to hold the exponents returned by frexp.
+ Type *ExTy;
+
+ /// How many bits of the quotient to compute per iteration of the
+ /// algorithm, stored as a value of type \p ExTy.
+ Value *Bits;
+
+ /// Constant 1 of type \p ExTy.
+ Value *One;
+
+ /// The sign bit for floating point values of type \p FremTy.
+ const unsigned long long Signbit;
+
+public:
+ /// Create an expander for frem instructions of floating point type \p Ty.
+ /// Returns std::nullopt if \p Ty is not a 16-bit FP type, float or double.
+ /// The constructor arguments following \p Ty are, in order: the number of
+ /// quotient bits per iteration, the sign-bit mask, the floating point type
+ /// used for the computation, and the integer type.
+ static std::optional<FRemExpander> create(IRBuilder<> &B, Type *Ty) {
+   // 16-bit types are computed in float.
+   if (Ty->is16bitFPTy())
+     return FRemExpander{B, Ty, 11, 0x8000, B.getFloatTy(), B.getInt16Ty()};
+   // Half is already covered by the 16-bit case above, so only float needs
+   // to be tested here.
+   if (Ty->isFloatTy())
+     return FRemExpander{B, Ty, 12, 0x80000000UL, Ty, B.getInt32Ty()};
+   if (Ty->isDoubleTy())
+     return FRemExpander{B, Ty, 26, 0x8000000000000000ULL, Ty, B.getInt64Ty()};
+
+   return std::nullopt;
+ }
+
+ /// Build the FRem expansion for the numerator \p X and the
+ /// denominator \p Y using the builder \p B. The type of X and Y
+ /// must match the type for which the class instance has been
+ /// created. The code will be generated at the insertion point of \p
+ /// B and the insertion point will be reset at exit.
+ Value *buildFRem(Value *X, Value *Y, std::optional<SimplifyQuery> &SQ) const;
+
+private:
+ /// Construct an expander. \p Bits is the number of quotient bits computed
+ /// per iteration and \p Signbit the sign-bit mask for values of type
+ /// \p FremTy; both integer constants are materialized with type ExTy.
+ FRemExpander(IRBuilder<> &B, Type *FremTy, short Bits,
+              unsigned long long Signbit, Type *ComputeFpTy, Type *IntTy)
+     : B(B), FremTy(FremTy), ComputeFpTy(ComputeFpTy), IntTy(IntTy),
+       // ExTy is declared before Bits and One, so it is already initialized
+       // when it is used to build those two constants.
+       ExTy(B.getInt32Ty()), Bits(ConstantInt::get(ExTy, Bits)),
+       One(ConstantInt::get(ExTy, 1)), Signbit(Signbit) {}
+
+ Value *createRcp(Value *V, const Twine &Name) const {
+   // Emit a plain 1.0 / V division named \p Name; turning it into an rcp
+   // instruction, where one is available, is left to later optimizations.
+   Value *OneFp = ConstantFP::get(ComputeFpTy, 1.0);
+   return B.CreateFDiv(OneFp, V, Name);
+ }
+
+ // Helper function to build the UPDATE_AX code which is common to the
+ // loop body and the "final iteration".
+ Value *buildUpdateAx(Value *Ax, Value *Ay, Value *Ayinv) const {
+   // Emits the equivalent of:
+   //   float q = BUILTIN_RINT_ComputeFpTy(ax * ayinv);
+   //   ax = fnma(q, ay, ax);
+   //   int clt = ax < 0.0f;
+   //   float axp = ax + ay;
+   //   ax = clt ? axp : ax;
+
+   // q = rint(ax * ayinv)
+   Value *Scaled = B.CreateFMul(Ax, Ayinv);
+   Value *Q = B.CreateUnaryIntrinsic(Intrinsic::rint, Scaled, {}, "q");
+
+   // ax - q * ay, expressed as fma(-q, ay, ax).
+   Value *NegQ = B.CreateFNeg(Q);
+   Value *Reduced = B.CreateFMA(NegQ, Ay, Ax, {}, "ax");
+
+   // If the reduced value is negative, select the value with ay added back.
+   Value *FpZero = ConstantFP::get(ComputeFpTy, 0.0);
+   Value *IsNegative = B.CreateFCmp(CmpInst::FCMP_OLT, Reduced, FpZero, "clt");
+   Value *Corrected = B.CreateFAdd(Reduced, Ay, "axp");
+   return B.CreateSelect(IsNegative, Corrected, Reduced, "ax");
+ }
+
+ /// Build code to extract the exponent and mantissa of \p Src.
+ /// Return a pair whose first element is the mantissa scaled by two to
+ /// the power of \p NewExp (via ldexp) and whose second element is the
+ /// exponent minus one, for use as a loop bound.
+ std::pair<Value *, Value *> buildExpAndPower(Value *Src, Value *NewExp,
+                                              const Twine &ExName,
+                                              const Twine &PowName) const {
+   // Build:
+   //   ExName = BUILTIN_FREXP_EXP_ComputeFpTy(Src) - 1;
+   //   PowName = BUILTIN_FLDEXP_ComputeFpTy(
+   //       BUILTIN_FREXP_MANT_ComputeFpTy(Src), NewExp);
+   // The returned pair is {PowName value, ExName value}.
+
+   // The exponent type is the ExTy member; no local copy is needed.
+   Value *Frexp =
+       B.CreateIntrinsic(Intrinsic::frexp, {Src->getType(), ExTy}, Src);
+   // frexp yields a struct: element 0 is the mantissa, element 1 the exponent.
+   Value *Mant = B.CreateExtractValue(Frexp, {0});
+   Value *Exp = B.CreateExtractValue(Frexp, {1});
+
+   Exp = B.CreateSub(Exp, One, ExName);
+   Value *Pow = B.CreateLdexp(Mant, NewExp, {}, PowName);
+
+   return {Pow, Exp};
+ }
+
+ /// Build the main computation of the remainder for the case in which
+ /// Ax > Ay, where Ax = |X|, Ay = |Y|, and X is the numerator and Y the
+ /// denominator. Add the incoming edge from the computation result
+ /// to \p RetPhi.
+ void buildRemainderComputation(Value *AxInitial, Value *AyInitial, Value *X,
+ PHINode *RetPhi) const {
+ // Emits the reduction starting at the builder's current insertion point
+ // and creates two new blocks, "frem.loop_body" and "frem.loop_exit".
+ // When this function returns, the builder is positioned in the loop
+ // exit block.
+ // Build:
+ // ex = BUILTIN_FREXP_EXP_ComputeFpTy(ax) - 1;
+ // ax = BUILTIN_FLDEXP_ComputeFpTy(
+ // BUILTIN_FREXP_MANT_ComputeFpTy(ax), bits);
+ // ey = BUILTIN_FREXP_EXP_ComputeFpTy(ay) - 1;
+ // ay = BUILTIN_FLDEXP_ComputeFpTy(
+ // BUILTIN_FREXP_MANT_ComputeFpTy(ay), 1);
+ auto [Ax, Ex] = buildExpAndPower(AxInitial, Bits, "ex", "ax");
+ auto [Ay, Ey] = buildExpAndPower(AyInitial, One, "ey", "ay");
+
+ // Build:
+ // int nb = ex - ey;
+ // float ayinv = MATH_FAST_RCP(ay);
+ Value *Nb = B.CreateSub(Ex, Ey, "nb");
+ Value *Ayinv = createRcp(Ay, "ayinv");
+
+ // Build: while (nb > bits)
+ BasicBlock *PreheaderBB = B.GetInsertBlock();
+ Function *Fun = PreheaderBB->getParent();
+ auto *LoopBB = BasicBlock::Create(B.getContext(), "frem.loop_body", Fun);
+ auto *ExitBB = BasicBlock::Create(B.getContext(), "frem.loop_exit", Fun);
+
+ // Entry test: go straight to the exit block when nb <= bits.
+ B.CreateCondBr(B.CreateICmp(CmpInst::ICMP_SGT, Nb, Bits), LoopBB, ExitBB);
+
+ // Build loop body:
+ // UPDATE_AX
+ // ax = BUILTIN_FLDEXP_ComputeFpTy(ax, bits);
+ // nb -= bits;
+ // One iteration of the loop is factored out. The code shared by
+ // the loop and this "iteration" is denoted by UPDATE_AX.
+ B.SetInsertPoint(LoopBB);
+ // Loop-carried values: nb_iv and ax_loop_phi each get one incoming value
+ // from the preheader and one from the back edge.
+ auto *NbIv = B.CreatePHI(Nb->getType(), 2, "nb_iv");
+ NbIv->addIncoming(Nb, PreheaderBB);
+
+ auto *AxPhi = B.CreatePHI(ComputeFpTy, 2, "ax_loop_phi");
+ AxPhi->addIncoming(Ax, PreheaderBB);
+
+ Value *AxPhiUpdate = buildUpdateAx(AxPhi, Ay, Ayinv);
+ AxPhiUpdate = B.CreateLdexp(AxPhiUpdate, Bits, {}, "ax_update");
+ AxPhi->addIncoming(AxPhiUpdate, LoopBB);
+ NbIv->addIncoming(B.CreateSub(NbIv, Bits, "nb_update"), LoopBB);
+
+ // The back-edge test compares the nb_iv phi (the value on entry to this
+ // iteration), not nb_update.
+ B.CreateCondBr(B.CreateICmp(CmpInst::ICMP_SGT, NbIv, Bits), LoopBB, ExitBB);
+
+ // Build final iteration
+ // ax = BUILTIN_FLDEXP_ComputeFpTy(ax, nb - bits + 1);
+ // UPDATE_AX
+ B.SetInsertPoint(ExitBB);
+
+ // On the edge from the loop, the exit phis receive the loop phis
+ // (ax_loop_phi / nb_iv), matching the value tested by the back-edge
+ // branch; on the edge from the preheader they receive the initial
+ // ax / nb.
+ auto *AxPhiExit = B.CreatePHI(ComputeFpTy, 2, "ax_exit_phi");
+ AxPhiExit->addIncoming(Ax, PreheaderBB);
+ AxPhiExit->addIncoming(AxPhi, LoopBB);
+ auto *NbExitPhi = B.CreatePHI(Nb->getType(), 2, "nb_exit_phi");
+ NbExitPhi->addIncoming(NbIv, LoopBB);
+ NbExitPhi->addIncoming(Nb, PreheaderBB);
+
+ Value *AxFinal = B.CreateLdexp(
+ AxPhiExit, B.CreateAdd(B.CreateSub(NbExitPhi, Bits), One), {}, "ax");
+ AxFinal = buildUpdateAx(AxFinal, Ay, Ayinv);
+
+ // Build:
+ // ax = BUILTIN_FLDEXP_ComputeFpTy(ax, ey);
+ // ret = AS_FLOAT((AS_INT(x) & SIGNBIT_SP32) ^ AS_INT(ax));
+ AxFinal = B.CreateLdexp(AxFinal, Ey, {}, "ax");
+
+ Value *XAsInt = B.CreateBitCast(X, IntTy, "x_as_int");
+ // If the computation type differs from X's type, truncate the result to
+ // X's type before reinterpreting it as an integer.
+ if (ComputeFpTy != X->getType())
+ AxFinal = B.CreateFPTrunc(AxFinal, X->getType());
+
+ Value *AxAsInt = B.CreateBitCast(AxFinal, IntTy, "ax_as_int");
+
+ // Isolate X's sign bit and xor it into the integer image of the result.
+ Value *Ret =
+ B.CreateXor(B.CreateAnd(XAsInt, Signbit), AxAsInt, "Remainder");
+ Ret = B.CreateBitCast(Ret, X->getType());
+
+ // The result flows into RetPhi from the loop exit block.
+ RetPhi->addIncoming(Ret, ExitBB);
+ }
+
+ /// Build the else-branch of the conditional in the FRem
+ /// expansion, i.e. the case in which Ax <= Ay, where Ax = |X|, Ay
+ /// = |Y|, and X is the numerator and Y the denominator. Add the
+ /// incoming edge from the result to \p RetPhi.
+ void buildElseBranch(Value *Ax, Value *Ay, Value *X, PHINode *RetPhi) const {
+   // Emits the equivalent of:
+   //   ret = ax == ay ? BUILTIN_COPYSIGN_ComputeFpTy(0.0f, x) : x;
+
+   // A zero of type FremTy carrying the sign of X.
+   Value *Zero = ConstantFP::get(FremTy, 0.0);
+   Value *SignedZero =
+       B.CreateIntrinsic(Intrinsic::copysign, {FremTy}, {Zero, X}, {});
+
+   // Equal magnitudes give the signed zero; otherwise X is passed through.
+   Value *MagnitudesEqual = B.CreateFCmpOEQ(Ax, Ay);
+   Value *Ret = B.CreateSelect(MagnitudesEqual, SignedZero, X);
+
+   // Register the result with the phi from the block being built.
+   BasicBlock *CurrentBB = B.GetInsertBlock();
+   RetPhi->addIncoming(Ret, CurrentBB);
+ }
+
+ /// Return a value that is NaN if one of the corner cases concerning
+ /// the inputs \p X and \p Y is detected, and \p Ret otherwise.
+ Value *handleInputCornerCases(Value *Ret, Value *X, Value *Y,
+ std::optional<SimplifyQuery> &SQ,
+ bool NoInfs) const {
+ // Build:
+ // ret = y == 0.0f ? QNAN_ComputeFpTy : ret;
+ // bool c = !BUILTIN_ISNAN_ComputeFpTy(y) &&
+ // BUILTIN_ISFINITE_ComputeFpTy(x);
+ // ret = c ? ret : QNAN_ComputeFpTy;
+ Value *Nan = ConstantFP::getQNaN(FremTy);
+ Ret = B.CreateSelect(B.CreateFCmpOEQ(Y, ConstantFP::get(FremTy, 0.0)), Nan,
+ Ret);
+ FPClassTest NotNan = FPClassTest::fcInf | FPClassTest::fcFinite;
+ Value *YNotNan = SQ && isKnownNeverNaN(Y, 0, *SQ)
+ ? B.getTrue()
+ : B.createIsFPClass(Y, NotNan);
----------------
arsenm wrote:
not nan should use fcmp ord instead
https://github.com/llvm/llvm-project/pull/130988
More information about the llvm-commits
mailing list