[llvm] [AMDGPU] Implement IR expansion for frem instruction (PR #130988)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 12 18:19:36 PDT 2025
================
@@ -37,6 +41,377 @@ static cl::opt<unsigned>
cl::desc("fp convert instructions on integers with "
"more than <N> bits are expanded."));
+namespace {
+/// This class implements a precise expansion of the frem instruction.
+/// The generated code is based on the fmod implementation in the AMD device
+/// libs.
+class FRemExpander {
+  /// The IRBuilder to use for the expansion.
+  IRBuilder<> &B;
+
+  /// Floating point type of the return value and the arguments of the FRem
+  /// instructions that should be expanded.
+  Type *FremTy;
+
+  /// Floating point type to use for the computation. This may be
+  /// wider than the \p FremTy.
+  Type *ComputeFpTy;
+
+  /// Integer type that can hold floating point values of type \p FremTy.
+  Type *IntTy;
+
+  /// Integer type used to hold the exponents returned by frexp.
+  Type *ExTy;
+
+  /// How many bits of the quotient to compute per iteration of the
+  /// algorithm, stored as a value of type \p ExTy.
+  Value *Bits;
+
+  /// Constant 1 of type \p ExTy.
+  Value *One;
+
+  /// Mask selecting the sign bit of floating point values of type \p FremTy.
+  /// Stored as a fixed-width 64-bit integer: `unsigned long` is only 32 bits
+  /// wide on LLP64 hosts and cannot represent the sign bit mask of a double.
+  const uint64_t Signbit;
+
+public:
+  /// Create an expander for frem instructions of type \p Ty that emits its
+  /// code through \p B. Returns std::nullopt if \p Ty is neither a 16-bit
+  /// floating point type, float, nor double.
+  static std::optional<FRemExpander> create(IRBuilder<> &B, Type *Ty) {
+    // 16-bit types are computed in float and bit-cast through i16.
+    if (Ty->is16bitFPTy())
+      return FRemExpander{B, Ty, 11, 0x8000, B.getFloatTy(), B.getInt16Ty()};
+    // No separate check for half here: it is a 16-bit type handled above.
+    if (Ty->isFloatTy())
+      return FRemExpander{B, Ty, 12, 0x80000000ULL, Ty, B.getInt32Ty()};
+    if (Ty->isDoubleTy())
+      return FRemExpander{B,  Ty, 26, 0x8000000000000000ULL,
+                          Ty, B.getInt64Ty()};
+
+    return std::nullopt;
+  }
+
+  /// Build the FRem expansion for the numerator \p X and the
+  /// denominator \p Y using the builder \p B. The type of X and Y
+  /// must match the type for which the class instance has been
+  /// created. The code will be generated at the insertion point of \p
+  /// B and the insertion point will be reset at exit.
+  Value *buildFRem(Value *X, Value *Y) const;
+
+private:
+  /// \p Bits is the number of quotient bits computed per iteration and
+  /// \p Signbit the sign bit mask for \p FremTy; \p ComputeFpTy is the
+  /// floating point type used for the computation and \p IntTy the integer
+  /// type used for bit-casts of \p FremTy values.
+  FRemExpander(IRBuilder<> &B, Type *FremTy, short Bits, uint64_t Signbit,
+               Type *ComputeFpTy, Type *IntTy)
+      : B(B), FremTy(FremTy), ComputeFpTy(ComputeFpTy), IntTy(IntTy),
+        ExTy(B.getInt32Ty()), Bits(ConstantInt::get(ExTy, Bits)),
+        One(ConstantInt::get(ExTy, 1)), Signbit(Signbit) {}
+
+  /// Emit a call to llvm.ldexp on \p Base and \p Exp, with \p Base of type
+  /// \p ComputeFpTy and \p Exp of type \p ExTy.
+  Value *createLdexp(Value *Base, Value *Exp, const Twine &Name) const {
+    return B.CreateIntrinsic(Intrinsic::ldexp, {ComputeFpTy, ExTy},
+                             {Base, Exp}, {}, Name);
+  }
+
+  /// Emit the reciprocal of \p V as the division 1.0 / V in \p ComputeFpTy.
+  Value *createRcp(Value *V, const Twine &Name) const {
+    return B.CreateFDiv(ConstantFP::get(ComputeFpTy, 1.0), V, Name);
+  }
+
+  // Helper function to build the UPDATE_AX code which is common to the
+  // loop body and the "final iteration".
+  Value *buildUpdateAx(Value *Ax, Value *Ay, Value *Ayinv) const {
+    // Build:
+    //   float q = BUILTIN_RINT_ComputeFpTy(ax * ayinv);
+    //   ax = fnma(q, ay, ax);
+    //   int clt = ax < 0.0f;
+    //   float axp = ax + ay;
+    //   ax = clt ? axp : ax;
+    Value *Q = B.CreateUnaryIntrinsic(Intrinsic::rint, B.CreateFMul(Ax, Ayinv),
+                                      {}, "q");
+    // fnma(q, ay, ax) is emitted as fma(-q, ay, ax).
+    Value *AxUpdate = B.CreateIntrinsic(Intrinsic::fma, {ComputeFpTy},
+                                        {B.CreateFNeg(Q), Ay, Ax}, {}, "ax");
+    Value *Clt = B.CreateFCmp(CmpInst::FCMP_OLT, AxUpdate,
+                              ConstantFP::get(ComputeFpTy, 0.0), "clt");
+    Value *Axp = B.CreateFAdd(AxUpdate, Ay, "axp");
+    AxUpdate = B.CreateSelect(Clt, Axp, AxUpdate, "ax");
+
+    return AxUpdate;
+  }
+
+  /// Build code to extract the exponent and mantissa of \p Src.
+  /// Return a pair whose first element is the mantissa scaled by two to
+  /// the power of \p NewExp (named \p PowName) and whose second element is
+  /// the exponent minus one (named \p ExName) for use as a loop bound.
+  std::pair<Value *, Value *> buildExpAndPower(Value *Src, Value *NewExp,
+                                               const Twine &ExName,
+                                               const Twine &PowName) const {
+    // Build:
+    //   ExName = BUILTIN_FREXP_EXP_ComputeFpTy(Src) - 1;
+    //   PowName = BUILTIN_FLDEXP_ComputeFpTy(
+    //                BUILTIN_FREXP_MANT_ComputeFpTy(Src), NewExp);
+    Type *Ty = Src->getType();
+    // llvm.frexp returns a {mantissa, exponent} aggregate.
+    Value *Frexp = B.CreateIntrinsic(Intrinsic::frexp, {Ty, ExTy}, Src);
+    Value *Mant = B.CreateExtractValue(Frexp, {0});
+    Value *Exp = B.CreateExtractValue(Frexp, {1});
+
+    Exp = B.CreateSub(Exp, One, ExName);
+    Value *Pow = createLdexp(Mant, NewExp, PowName);
+
+    return {Pow, Exp};
+  }
+
+  /// Build the main computation of the remainder for the case in which
+  /// Ax > Ay, where Ax = |X|, Ay = |Y|, and X is the numerator and Y the
+  /// denominator. Add the incoming edge from the computation result
+  /// to \p RetPhi.
+  void buildRemainderComputation(Value *AxInitial, Value *AyInitial, Value *X,
+                                 PHINode *RetPhi) const {
+    // Build:
+    //   ex = BUILTIN_FREXP_EXP_ComputeFpTy(ax) - 1;
+    //   ax = BUILTIN_FLDEXP_ComputeFpTy(
+    //           BUILTIN_FREXP_MANT_ComputeFpTy(ax), bits);
+    //   ey = BUILTIN_FREXP_EXP_ComputeFpTy(ay) - 1;
+    //   ay = BUILTIN_FLDEXP_ComputeFpTy(
+    //           BUILTIN_FREXP_MANT_ComputeFpTy(ay), 1);
+    auto [Ax, Ex] = buildExpAndPower(AxInitial, Bits, "ex", "ax");
+    auto [Ay, Ey] = buildExpAndPower(AyInitial, One, "ey", "ay");
+
+    // Build:
+    //   int nb = ex - ey;
+    //   float ayinv = MATH_FAST_RCP(ay);
+    Value *Nb = B.CreateSub(Ex, Ey, "nb");
+    Value *Ayinv = createRcp(Ay, "ayinv");
+
+    // Build: while (nb > bits)
+    BasicBlock *PreheaderBB = B.GetInsertBlock();
+    Function *Fun = PreheaderBB->getParent();
+    auto *LoopBB = BasicBlock::Create(B.getContext(), "frem.loop_body", Fun);
+    auto *ExitBB = BasicBlock::Create(B.getContext(), "frem.loop_exit", Fun);
+
+    B.CreateCondBr(B.CreateICmp(CmpInst::ICMP_SGT, Nb, Bits), LoopBB, ExitBB);
+
+    // Build loop body:
+    //   UPDATE_AX
+    //   ax = BUILTIN_FLDEXP_ComputeFpTy(ax, bits);
+    //   nb -= bits;
+    // One iteration of the loop is factored out. The code shared by
+    // the loop and this "iteration" is denoted by UPDATE_AX.
+    B.SetInsertPoint(LoopBB);
+    auto *NbIv = B.CreatePHI(Nb->getType(), 2, "nb_iv");
+    NbIv->addIncoming(Nb, PreheaderBB);
+
+    auto *AxPhi = B.CreatePHI(ComputeFpTy, 2, "ax_loop_phi");
+    AxPhi->addIncoming(Ax, PreheaderBB);
+
+    Value *AxPhiUpdate = buildUpdateAx(AxPhi, Ay, Ayinv);
+    AxPhiUpdate = createLdexp(AxPhiUpdate, Bits, "ax_update");
+    AxPhi->addIncoming(AxPhiUpdate, LoopBB);
+    NbIv->addIncoming(B.CreateSub(NbIv, Bits, "nb_update"), LoopBB);
+
+    // NOTE(review): the back-edge condition tests the phi value NbIv, not
+    // "nb_update", and the exit phis below likewise take the phi values AxPhi
+    // and NbIv. The exit values therefore match the while-loop above, but the
+    // loop body is emitted so that its last execution computes an update that
+    // is discarded. Branching on "nb_update" and feeding the updated values to
+    // the exit phis would avoid that extra iteration.
+    B.CreateCondBr(B.CreateICmp(CmpInst::ICMP_SGT, NbIv, Bits), LoopBB, ExitBB);
+
+    // Build final iteration
+    //   ax = BUILTIN_FLDEXP_ComputeFpTy(ax, nb - bits + 1);
+    //   UPDATE_AX
+    B.SetInsertPoint(ExitBB);
+
+    auto *AxPhiExit = B.CreatePHI(ComputeFpTy, 2, "ax_exit_phi");
+    AxPhiExit->addIncoming(Ax, PreheaderBB);
+    AxPhiExit->addIncoming(AxPhi, LoopBB);
+    auto *NbExitPhi = B.CreatePHI(Nb->getType(), 2, "nb_exit_phi");
+    NbExitPhi->addIncoming(NbIv, LoopBB);
+    NbExitPhi->addIncoming(Nb, PreheaderBB);
+
+    Value *AxFinal = createLdexp(
+        AxPhiExit, B.CreateAdd(B.CreateSub(NbExitPhi, Bits), One), "ax");
+    AxFinal = buildUpdateAx(AxFinal, Ay, Ayinv);
+
+    // Build:
+    //   ax = BUILTIN_FLDEXP_ComputeFpTy(ax, ey);
+    //   ret = AS_FLOAT((AS_INT(x) & SIGNBIT_SP32) ^ AS_INT(ax));
+    AxFinal = createLdexp(AxFinal, Ey, "ax");
+
+    Value *XAsInt = B.CreateBitCast(X, IntTy, "x_as_int");
+    // Narrow back to the type of X if the computation used a wider type.
+    if (ComputeFpTy != X->getType())
+      AxFinal = B.CreateFPTrunc(AxFinal, X->getType());
+
+    Value *AxAsInt = B.CreateBitCast(AxFinal, IntTy, "ax_as_int");
+
+    // Transfer the sign of X onto the computed magnitude.
+    Value *Ret =
+        B.CreateXor(B.CreateAnd(XAsInt, Signbit), AxAsInt, "Remainder");
+    Ret = B.CreateBitCast(Ret, X->getType());
+
+    RetPhi->addIncoming(Ret, ExitBB);
+  }
+
+  /// Build the else-branch of the conditional in the FRem
+  /// expansion, i.e. the case in which Ax <= Ay, where Ax = |X|, Ay
+  /// = |Y|, and X is the numerator and Y the denominator. Add the
+  /// incoming edge from the result to \p RetPhi.
+  void buildElseBranch(Value *Ax, Value *Ay, Value *X, PHINode *RetPhi) const {
+    // Build:
+    //   ret = ax == ay ? BUILTIN_COPYSIGN_ComputeFpTy(0.0f, x) : x;
+    Value *ZeroWithXSign = B.CreateIntrinsic(
+        Intrinsic::copysign, {FremTy}, {ConstantFP::get(FremTy, 0.0), X}, {});
+
+    Value *Ret = B.CreateSelect(B.CreateFCmpOEQ(Ax, Ay), ZeroWithXSign, X);
+
+    RetPhi->addIncoming(Ret, B.GetInsertBlock());
+  }
+
+  /// Adjust the result of the main computation from the FRem expansion
+  /// if NaNs or infinite values are possible.
+  Value *buildNanAndInfHandling(Value *Ret, Value *X, Value *Y) const {
+    // Build:
+    //   ret = y == 0.0f ? QNAN_ComputeFpTy : ret;
+    //   bool c = !BUILTIN_ISNAN_ComputeFpTy(y) &&
+    //            BUILTIN_ISFINITE_ComputeFpTy(x);
+    //   ret = c ? ret : QNAN_ComputeFpTy;
+    // TODO Handle NaN and infinity fast math flags separately here?
+    Value *Nan = ConstantFP::getQNaN(FremTy);
+
+    Ret = B.CreateSelect(B.createIsFPClass(Y, FPClassTest::fcZero), Nan, Ret);
+    Value *C = B.CreateLogicalAnd(
+        B.CreateNot(B.createIsFPClass(Y, FPClassTest::fcNan)),
+        B.createIsFPClass(X, FPClassTest::fcFinite));
+    Ret = B.CreateSelect(C, Ret, Nan);
+
+    return Ret;
+  }
+};
+
+Value *FRemExpander::buildFRem(Value *X, Value *Y) const {
+ assert(X->getType() == FremTy && Y->getType() == FremTy);
+
+ FastMathFlags FMF = B.getFastMathFlags();
+
+ // This function generates the following code structure:
+ // if (abs(x) > abs(y))
+ // { ret = compute remainder }
+ // else
+ // { ret = x or 0 with sign of x }
+ // Adjust ret to NaN/inf in input
+ // return ret
+ Value *Ax = B.CreateUnaryIntrinsic(Intrinsic::fabs, X, {}, "ax");
+ Value *Ay = B.CreateUnaryIntrinsic(Intrinsic::fabs, Y, {}, "ay");
+ if (ComputeFpTy != X->getType()) {
+ Ax = B.CreateFPExt(Ax, ComputeFpTy, "ax");
+ Ay = B.CreateFPExt(Ay, ComputeFpTy, "ay");
+ }
+ Value *AxAyCmp = B.CreateFCmpOGT(Ax, Ay);
+
+ PHINode *RetPhi = B.CreatePHI(FremTy, 2, "ret");
+ Value *Ret = RetPhi;
+
+ if (!FMF.noNaNs() || !FMF.noInfs())
----------------
arsenm wrote:
TODO use isKnownNeverNaN
https://github.com/llvm/llvm-project/pull/130988
More information about the llvm-commits
mailing list