[llvm] 8f38138 - AMDGPU: Refactor libcall simplify to help with future refined fast math flag usage
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 31 08:23:21 PDT 2023
Author: Matt Arsenault
Date: 2023-07-31T11:23:12-04:00
New Revision: 8f381380905c159bafb8db3b9ce699ec182e892c
URL: https://github.com/llvm/llvm-project/commit/8f381380905c159bafb8db3b9ce699ec182e892c
DIFF: https://github.com/llvm/llvm-project/commit/8f381380905c159bafb8db3b9ce699ec182e892c.diff
LOG: AMDGPU: Refactor libcall simplify to help with future refined fast math flag usage
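This refactor threads the FPMathOperator through the per-call fold entry
points (pow/powr/pown, rootn, sqrt, sin/cos) instead of the raw CallInst,
caches the function-level "unsafe-fp-math" attribute in a new
initFunction() hook, and splits canIncreasePrecisionOfConstantFold() out
of isUnsafeMath() so later patches can gate individual folds on refined
fast math flags rather than the blanket unsafe-math predicate. A minimal
sketch of the cached-attribute pattern, using a hypothetical stand-in
struct rather than the real AMDGPULibCalls interface:

  #include "llvm/IR/Function.h"
  #include "llvm/IR/Operator.h"

  using namespace llvm;

  struct LibCallFolderSketch {
    bool UnsafeFPMath = false;

    // Read the function attribute once per function...
    void initFunction(const Function &F) {
      UnsafeFPMath = F.getFnAttribute("unsafe-fp-math").getValueAsBool();
    }

    // ...then combine it with the per-instruction 'fast' flag at each
    // call site, instead of walking back to the parent function on
    // every query as the old isUnsafeMath(const CallInst *) did.
    bool isUnsafeMath(const FPMathOperator *FPOp) const {
      return UnsafeFPMath || FPOp->isFast();
    }
  };
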
https://reviews.llvm.org/D156678
Added:
Modified:
llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
index f6f187923e61ff..a9dc094c2cfafb 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
@@ -51,6 +51,8 @@ class AMDGPULibCalls {
const TargetMachine *TM;
+ bool UnsafeFPMath = false;
+
// -fuse-native.
bool AllNative = false;
@@ -73,10 +75,10 @@ class AMDGPULibCalls {
bool fold_divide(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
// pow/powr/pown
- bool fold_pow(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+ bool fold_pow(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
// rootn
- bool fold_rootn(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+ bool fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
// fma/mad
bool fold_fma_mad(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
@@ -90,10 +92,10 @@ class AMDGPULibCalls {
bool evaluateCall(CallInst *aCI, const FuncInfo &FInfo);
// sqrt
- bool fold_sqrt(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+ bool fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
// sin/cos
- bool fold_sincos(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo,
+ bool fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo,
AliasAnalysis *AA);
// __read_pipe/__write_pipe
@@ -113,7 +115,9 @@ class AMDGPULibCalls {
protected:
CallInst *CI;
- bool isUnsafeMath(const CallInst *CI) const;
+ bool isUnsafeMath(const FPMathOperator *FPOp) const;
+
+ bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const;
void replaceCall(Value *With) {
CI->replaceAllUsesWith(With);
@@ -125,6 +129,7 @@ class AMDGPULibCalls {
bool fold(CallInst *CI, AliasAnalysis *AA = nullptr);
+ void initFunction(const Function &F);
void initNativeFuncs();
// Replace a normal math function call with that native version
@@ -445,13 +450,18 @@ bool AMDGPULibCalls::parseFunctionName(const StringRef &FMangledName,
return AMDGPULibFunc::parse(FMangledName, FInfo);
}
-bool AMDGPULibCalls::isUnsafeMath(const CallInst *CI) const {
- if (auto Op = dyn_cast<FPMathOperator>(CI))
- if (Op->isFast())
- return true;
- const Function *F = CI->getParent()->getParent();
- Attribute Attr = F->getFnAttribute("unsafe-fp-math");
- return Attr.getValueAsBool();
+bool AMDGPULibCalls::isUnsafeMath(const FPMathOperator *FPOp) const {
+ return UnsafeFPMath || FPOp->isFast();
+}
+
+bool AMDGPULibCalls::canIncreasePrecisionOfConstantFold(
+ const FPMathOperator *FPOp) const {
+ // TODO: Refine to approxFunc or contract
+ return isUnsafeMath(FPOp);
+}
+
+void AMDGPULibCalls::initFunction(const Function &F) {
+ UnsafeFPMath = F.getFnAttribute("unsafe-fp-math").getValueAsBool();
}
bool AMDGPULibCalls::useNativeFunc(const StringRef F) const {
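canIncreasePrecisionOfConstantFold() is a deliberate alias for
isUnsafeMath() for now; its TODO points at checking individual fast math
flags instead. One plausible reading of that refinement (an assumption
about the follow-up, not code in this commit) would test the afn and
contract bits of the call directly:

  // Hypothetical follow-up sketch: gate precision-increasing constant
  // folds on the call's own afn/contract flags. FastMathFlags here is
  // llvm::FastMathFlags (llvm/IR/FMF.h).
  bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) {
    FastMathFlags FMF = FPOp->getFastMathFlags();
    return FMF.approxFunc() || FMF.allowContract();
  }
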
@@ -620,65 +630,61 @@ bool AMDGPULibCalls::fold(CallInst *CI, AliasAnalysis *AA) {
if (TDOFold(CI, FInfo))
return true;
- // Under unsafe-math, evaluate calls if possible.
- // According to Brian Sumner, we can do this for all f32 function calls
- // using host's double function calls.
- if (isUnsafeMath(CI) && evaluateCall(CI, FInfo))
- return true;
+ if (FPMathOperator *FPOp = dyn_cast<FPMathOperator>(CI)) {
+ // Under unsafe-math, evaluate calls if possible.
+ // According to Brian Sumner, we can do this for all f32 function calls
+ // using host's double function calls.
+ if (canIncreasePrecisionOfConstantFold(FPOp) && evaluateCall(CI, FInfo))
+ return true;
- // Copy fast flags from the original call.
- if (const FPMathOperator *FPOp = dyn_cast<const FPMathOperator>(CI))
+ // Copy fast flags from the original call.
B.setFastMathFlags(FPOp->getFastMathFlags());
- // Specialized optimizations for each function call
- switch (FInfo.getId()) {
- case AMDGPULibFunc::EI_RECIP:
- // skip vector function
- assert ((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
- FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
- "recip must be an either native or half function");
- return (getVecSize(FInfo) != 1) ? false : fold_recip(CI, B, FInfo);
-
- case AMDGPULibFunc::EI_DIVIDE:
- // skip vector function
- assert ((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
- FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
- "divide must be an either native or half function");
- return (getVecSize(FInfo) != 1) ? false : fold_divide(CI, B, FInfo);
-
- case AMDGPULibFunc::EI_POW:
- case AMDGPULibFunc::EI_POWR:
- case AMDGPULibFunc::EI_POWN:
- return fold_pow(CI, B, FInfo);
-
- case AMDGPULibFunc::EI_ROOTN:
- // skip vector function
- return (getVecSize(FInfo) != 1) ? false : fold_rootn(CI, B, FInfo);
-
- case AMDGPULibFunc::EI_FMA:
- case AMDGPULibFunc::EI_MAD:
- case AMDGPULibFunc::EI_NFMA:
- // skip vector function
- return (getVecSize(FInfo) != 1) ? false : fold_fma_mad(CI, B, FInfo);
-
- case AMDGPULibFunc::EI_SQRT:
- return isUnsafeMath(CI) && fold_sqrt(CI, B, FInfo);
- case AMDGPULibFunc::EI_COS:
- case AMDGPULibFunc::EI_SIN:
- if ((getArgType(FInfo) == AMDGPULibFunc::F32 ||
- getArgType(FInfo) == AMDGPULibFunc::F64)
- && (FInfo.getPrefix() == AMDGPULibFunc::NOPFX))
- return fold_sincos(CI, B, FInfo, AA);
-
- break;
- case AMDGPULibFunc::EI_READ_PIPE_2:
- case AMDGPULibFunc::EI_READ_PIPE_4:
- case AMDGPULibFunc::EI_WRITE_PIPE_2:
- case AMDGPULibFunc::EI_WRITE_PIPE_4:
- return fold_read_write_pipe(CI, B, FInfo);
-
- default:
- break;
+ // Specialized optimizations for each function call
+ switch (FInfo.getId()) {
+ case AMDGPULibFunc::EI_POW:
+ case AMDGPULibFunc::EI_POWR:
+ case AMDGPULibFunc::EI_POWN:
+ return fold_pow(FPOp, B, FInfo);
+ case AMDGPULibFunc::EI_ROOTN:
+ return fold_rootn(FPOp, B, FInfo);
+ case AMDGPULibFunc::EI_SQRT:
+ return fold_sqrt(FPOp, B, FInfo);
+ case AMDGPULibFunc::EI_COS:
+ case AMDGPULibFunc::EI_SIN:
+ return fold_sincos(FPOp, B, FInfo, AA);
+ case AMDGPULibFunc::EI_RECIP:
+ // skip vector function
+ assert((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
+ FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
+ "recip must be an either native or half function");
+ return (getVecSize(FInfo) != 1) ? false : fold_recip(CI, B, FInfo);
+
+ case AMDGPULibFunc::EI_DIVIDE:
+ // skip vector function
+ assert((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
+ FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
+ "divide must be an either native or half function");
+ return (getVecSize(FInfo) != 1) ? false : fold_divide(CI, B, FInfo);
+ case AMDGPULibFunc::EI_FMA:
+ case AMDGPULibFunc::EI_MAD:
+ case AMDGPULibFunc::EI_NFMA:
+ // skip vector function
+ return (getVecSize(FInfo) != 1) ? false : fold_fma_mad(CI, B, FInfo);
+ default:
+ break;
+ }
+ } else {
+ // Specialized optimizations for each function call
+ switch (FInfo.getId()) {
+ case AMDGPULibFunc::EI_READ_PIPE_2:
+ case AMDGPULibFunc::EI_READ_PIPE_4:
+ case AMDGPULibFunc::EI_WRITE_PIPE_2:
+ case AMDGPULibFunc::EI_WRITE_PIPE_4:
+ return fold_read_write_pipe(CI, B, FInfo);
+ default:
+ break;
+ }
}
return false;
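
Note that fold() now copies the call's fast math flags onto the IRBuilder
once, up front, so every FP-specific fold emits replacement instructions
that inherit the original call's flags. In isolation the pattern looks
like this (a minimal fragment using the standard IRBuilder API, with
CreateFMul standing in for whatever a particular fold actually emits):

  #include "llvm/IR/IRBuilder.h"
  #include "llvm/IR/Operator.h"

  using namespace llvm;

  Value *buildWithCallFlags(CallInst *CI, Value *X, Value *Y) {
    IRBuilder<> B(CI); // insert replacement code before the call
    if (auto *FPOp = dyn_cast<FPMathOperator>(CI))
      B.setFastMathFlags(FPOp->getFastMathFlags());
    // Any FP instruction created through B carries the copied flags.
    return B.CreateFMul(X, Y);
  }
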
@@ -796,7 +802,7 @@ static double log2(double V) {
}
}
-bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
const FuncInfo &FInfo) {
assert((FInfo.getId() == AMDGPULibFunc::EI_POW ||
FInfo.getId() == AMDGPULibFunc::EI_POWR ||
@@ -827,7 +833,7 @@ bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
}
// No unsafe math, no constant argument, do nothing
- if (!isUnsafeMath(CI) && !CF && !CINT && !CZero)
+ if (!isUnsafeMath(FPOp) && !CF && !CINT && !CZero)
return false;
// 0x1111111 means that we don't do anything for this call.
@@ -885,7 +891,7 @@ bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
}
}
- if (!isUnsafeMath(CI))
+ if (!isUnsafeMath(FPOp))
return false;
// Unsafe Math optimization
@@ -1079,10 +1085,14 @@ bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
return true;
}
-bool AMDGPULibCalls::fold_rootn(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
const FuncInfo &FInfo) {
- Value *opr0 = CI->getArgOperand(0);
- Value *opr1 = CI->getArgOperand(1);
+ // skip vector function
+ if (getVecSize(FInfo) != 1)
+ return false;
+
+ Value *opr0 = FPOp->getOperand(0);
+ Value *opr1 = FPOp->getOperand(1);
ConstantInt *CINT = dyn_cast<ConstantInt>(opr1);
if (!CINT) {
@@ -1188,8 +1198,11 @@ FunctionCallee AMDGPULibCalls::getNativeFunction(Module *M,
}
// fold sqrt -> native_sqrt (x)
-bool AMDGPULibCalls::fold_sqrt(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
const FuncInfo &FInfo) {
+ if (!isUnsafeMath(FPOp))
+ return false;
+
if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
(FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
if (FunctionCallee FPExpr = getNativeFunction(
@@ -1206,10 +1219,16 @@ bool AMDGPULibCalls::fold_sqrt(CallInst *CI, IRBuilder<> &B,
}
// fold sin, cos -> sincos.
-bool AMDGPULibCalls::fold_sincos(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
const FuncInfo &fInfo, AliasAnalysis *AA) {
assert(fInfo.getId() == AMDGPULibFunc::EI_SIN ||
fInfo.getId() == AMDGPULibFunc::EI_COS);
+
+ if ((getArgType(fInfo) != AMDGPULibFunc::F32 &&
+ getArgType(fInfo) != AMDGPULibFunc::F64) ||
+ fInfo.getPrefix() != AMDGPULibFunc::NOPFX)
+ return false;
+
bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;
Value *CArgVal = CI->getArgOperand(0);
@@ -1651,6 +1670,8 @@ bool AMDGPUSimplifyLibCalls::runOnFunction(Function &F) {
if (skipFunction(F))
return false;
+ Simplifier.initFunction(F);
+
bool Changed = false;
auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
@@ -1675,6 +1696,7 @@ PreservedAnalyses AMDGPUSimplifyLibCallsPass::run(Function &F,
FunctionAnalysisManager &AM) {
AMDGPULibCalls Simplifier(&TM);
Simplifier.initNativeFuncs();
+ Simplifier.initFunction(F);
bool Changed = false;
auto AA = &AM.getResult<AAManager>(F);
@@ -1701,6 +1723,8 @@ bool AMDGPUUseNativeCalls::runOnFunction(Function &F) {
if (skipFunction(F) || UseNative.empty())
return false;
+ Simplifier.initFunction(F);
+
bool Changed = false;
for (auto &BB : F) {
for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ) {
@@ -1721,6 +1745,7 @@ PreservedAnalyses AMDGPUUseNativeCallsPass::run(Function &F,
AMDGPULibCalls Simplifier;
Simplifier.initNativeFuncs();
+ Simplifier.initFunction(F);
bool Changed = false;
for (auto &BB : F) {
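
All four entry points (the legacy and new-PM forms of both
AMDGPUSimplifyLibCalls and AMDGPUUseNativeCalls) now call initFunction(F)
after initNativeFuncs(), so the per-function attribute cache is populated
before any call site is visited. The SimplifyLibCalls pass shape,
condensed (AA comes from the pass's analysis results, as in the hunks
above):

  AMDGPULibCalls Simplifier(&TM);
  Simplifier.initNativeFuncs();  // per-pass setup
  Simplifier.initFunction(F);    // per-function "unsafe-fp-math" cache

  bool Changed = false;
  for (auto &BB : F) {
    for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E;) {
      Instruction *Inst = &*I++; // advance before a fold may erase *I
      if (auto *CI = dyn_cast<CallInst>(Inst))
        Changed |= Simplifier.fold(CI, AA);
    }
  }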