[llvm] [AMDGPU] Implement IR expansion for frem instruction (PR #130988)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 21 06:05:06 PDT 2025
================
@@ -37,6 +49,340 @@ static cl::opt<unsigned>
cl::desc("fp convert instructions on integers with "
"more than <N> bits are expanded."));
+namespace {
+/// This class implements a precise expansion of the frem instruction.
+/// The generated code is based on the fmod implementation in the AMD device
+/// libs.
+class FRemExpander {
+ /// The IRBuilder to use for the expansion.
+ IRBuilder<> &B;
+
+ /// Floating point type of the return value and the arguments of the FRem
+ /// instructions that should be expanded.
+ Type *FremTy;
+
+ /// Floating point type to use for the computation. This may be
+ /// wider than the \p FremTy.
+ Type *ComputeFpTy;
+
+ /// Integer type used to hold the exponents returned by frexp.
+ Type *ExTy;
+
+ /// How many bits of the quotient to compute per iteration of the
+ /// algorithm, stored as a value of type \p ExTy.
+ Value *Bits;
+
+ /// Constant 1 of type \p ExTy.
+ Value *One;
+
+public:
+ static std::optional<FRemExpander> create(IRBuilder<> &B, Type *Ty) {
+ // TODO: The expansion should work for other types as well, but
+ // this would require additional testing.
+ if (!Ty->isIEEELikeFPTy() || Ty->isBFloatTy() || Ty->isFP128Ty())
+ return std::nullopt;
+
+ // The type to use for the computation of the remainder. This may be
+ // wider than the input/result type which affects the ...
+ Type *ComputeTy = Ty;
+ // ... maximum number of iterations of the remainder computation loop
+ // to use. This value is for the case in which the computation
+ // uses the same input/result type.
+ unsigned MaxIter = 2;
+
+ if (Ty->is16bitFPTy()) {
+ // Use the wider type and fewer iterations.
+ ComputeTy = B.getFloatTy();
+ MaxIter = 1;
+ }
+
+ unsigned Precision =
+ llvm::APFloat::semanticsPrecision(Ty->getFltSemantics());
+ return FRemExpander{B, Ty, Precision / MaxIter, ComputeTy};
+ }
+
+ /// Build the FRem expansion for the numerator \p X and the
+ /// denominator \p Y using the builder \p B. The type of X and Y
+ /// must match the type for which the class instance has been
+ /// created. The code will be generated at the insertion point of \p
+ /// B and the insertion point will be reset at exit.
+ Value *buildFRem(Value *X, Value *Y, std::optional<SimplifyQuery> &SQ) const;
+
+private:
+ FRemExpander(IRBuilder<> &B, Type *FremTy, unsigned Bits, Type *ComputeFpTy)
+ : B(B), FremTy(FremTy), ComputeFpTy(ComputeFpTy), ExTy(B.getInt32Ty()),
+ Bits(ConstantInt::get(ExTy, Bits)), One(ConstantInt::get(ExTy, 1)) {}
+
+ Value *createRcp(Value *V, const Twine &Name) const {
+ // Leave it to later optimizations to turn this into an rcp
+ // instruction if available.
+ return B.CreateFDiv(ConstantFP::get(ComputeFpTy, 1.0), V, Name);
+ }
+
+ // Helper function to build the UPDATE_AX code which is common to the
+ // loop body and the "final iteration".
+ Value *buildUpdateAx(Value *Ax, Value *Ay, Value *Ayinv) const {
+ // Build:
+ // float q = rint(ax * ayinv);
+ // ax = fma(-q, ay, ax);
----------------
arsenm wrote:
We appear to be missing folds of fneg with rint. In this case it works out: pulling the fneg away from the fma would not save encoding size. In general, though, we should probably have `fneg (rint(x)) -> rint(fneg(x))`.
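The suggested fold is sound for `rint` in round-to-nearest-even mode, which LLVM's default floating-point environment assumes for non-constrained operations: that mode rounds symmetrically, so negating before or after rounding yields the same integer. The sketch below checks this on the quoted UPDATE_AX step, using plain `<cmath>` doubles as a stand-in for the IR. `updateAx`, `updateAxFolded` and `checkFoldAt` are hypothetical names, not LLVM APIs.

```cpp
#include <cassert>
#include <cfenv>
#include <cmath>

// UPDATE_AX step as quoted in the diff:
//   q = rint(ax * ayinv); ax = fma(-q, ay, ax);
double updateAx(double ax, double ay, double ayinv) {
  double q = std::rint(ax * ayinv);
  return std::fma(-q, ay, ax);
}

// The same step after the proposed fneg (rint(x)) -> rint(fneg(x)) rewrite:
// the negation is applied to rint's operand instead of to its result.
double updateAxFolded(double ax, double ay, double ayinv) {
  double negQ = std::rint(-(ax * ayinv));
  return std::fma(negQ, ay, ax);
}

// Checks the identity for one finite value. It relies on the default
// round-to-nearest-even mode, which rounds symmetrically around zero.
void checkFoldAt(double v) {
  assert(std::fegetround() == FE_TONEAREST);
  assert(std::rint(-v) == -std::rint(v));
}
```

For example, `updateAx(7.0, 2.0, 0.5)` and `updateAxFolded(7.0, 2.0, 0.5)` both return -1.0, since `rint(3.5)` is 4 under ties-to-even. Under a directed mode such as `FE_UPWARD` the identity would fail: `rint(-2.5)` is -2 while `-rint(2.5)` is -3.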
https://github.com/llvm/llvm-project/pull/130988
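The overall loop that `FRemExpander` builds can be modeled on scalars, written against `<cmath>` instead of IRBuilder. This is a sketch under stated assumptions: it models only `double`, and the helper names (`bitsPerIteration`, `updateAx`, `reduceMagnitude`, `fremSketch`) are hypothetical. The frexp exponent bookkeeping and the correction of a negative `ax` after each step are reconstructed from the fragments in the diff (the frexp exponents, `Bits`, UPDATE_AX), not copied from the patch or the AMD device libs. `1.0 / ay` stands in for `createRcp`.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Quotient bits reduced per loop iteration, mirroring create():
// Precision / MaxIter, with MaxIter = 1 for 16-bit types (computed in float)
// and MaxIter = 2 otherwise.
int bitsPerIteration(int precision, bool is16Bit) {
  const int maxIter = is16Bit ? 1 : 2;
  return precision / maxIter;
}

// UPDATE_AX plus a fix-up that keeps ax in [0, ay). The fix-up is an
// assumption of this sketch; the quoted diff only shows the rint/fma lines.
static double updateAx(double ax, double ay, double ayinv) {
  double q = std::rint(ax * ayinv);
  ax = std::fma(-q, ay, ax); // exact: the remainder is representable
  return ax < 0.0 ? ax + ay : ax;
}

// Remainder of ax by ay for finite ax > ay > 0.
static double reduceMagnitude(double ax, double ay, int bits) {
  assert(ax > ay && ay > 0.0 && std::isfinite(ax));
  int ex = 0, ey = 0;
  const double mx = std::frexp(ax, &ex); // mx in [0.5, 1)
  const double my = std::frexp(ay, &ey);
  ex -= 1;                       // |x| == (2 * mx) * 2^ex
  ey -= 1;                       // |y| == (2 * my) * 2^ey
  ax = std::ldexp(mx, bits);     // in [2^(bits-1), 2^bits)
  ay = std::ldexp(my, 1);        // in [1, 2)
  const double ayinv = 1.0 / ay; // stands in for createRcp()
  int nb = ex - ey;              // exponent gap still to be reduced
  while (nb > bits) {
    ax = std::ldexp(updateAx(ax, ay, ayinv), bits);
    nb -= bits;
  }
  // Rescale so ax is congruent to |x| / 2^ey modulo ay, then do the
  // final iteration and scale the remainder back by 2^ey.
  ax = std::ldexp(ax, nb - bits + 1);
  ax = updateAx(ax, ay, ayinv);
  return std::ldexp(ax, ey);
}

// Hypothetical scalar model of the frem expansion for double.
double fremSketch(double x, double y) {
  if (std::isnan(x) || std::isnan(y) || std::isinf(x) || y == 0.0)
    return std::numeric_limits<double>::quiet_NaN();
  const double ax = std::fabs(x), ay = std::fabs(y);
  const int bits =
      bitsPerIteration(std::numeric_limits<double>::digits, false); // 26
  const double r =
      ax > ay ? reduceMagnitude(ax, ay, bits) : (ax == ay ? 0.0 : ax);
  return std::copysign(r, x); // the result takes the sign of x, like fmod
}
```

With this model, `fremSketch(7.0, 2.0)` returns 1.0 and `fremSketch(-7.0, 2.0)` returns -1.0, matching `std::fmod`. Each `fma` produces an exactly representable remainder, so the only role of the approximate reciprocal is to estimate `q` closely enough for the fix-up to land `ax` in `[0, ay)`.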