[llvm] 8f38138 - AMDGPU: Refactor libcall simplify to help with future refined fast math flag usage

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 31 08:23:21 PDT 2023


Author: Matt Arsenault
Date: 2023-07-31T11:23:12-04:00
New Revision: 8f381380905c159bafb8db3b9ce699ec182e892c

URL: https://github.com/llvm/llvm-project/commit/8f381380905c159bafb8db3b9ce699ec182e892c
DIFF: https://github.com/llvm/llvm-project/commit/8f381380905c159bafb8db3b9ce699ec182e892c.diff

LOG: AMDGPU: Refactor libcall simplify to help with future refined fast math flag usage

https://reviews.llvm.org/D156678

Added: 
    

Modified: 
    llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
index f6f187923e61ff..a9dc094c2cfafb 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
@@ -51,6 +51,8 @@ class AMDGPULibCalls {
 
   const TargetMachine *TM;
 
+  bool UnsafeFPMath = false;
+
   // -fuse-native.
   bool AllNative = false;
 
@@ -73,10 +75,10 @@ class AMDGPULibCalls {
   bool fold_divide(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // pow/powr/pown
-  bool fold_pow(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_pow(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // rootn
-  bool fold_rootn(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // fma/mad
   bool fold_fma_mad(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
@@ -90,10 +92,10 @@ class AMDGPULibCalls {
   bool evaluateCall(CallInst *aCI, const FuncInfo &FInfo);
 
   // sqrt
-  bool fold_sqrt(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // sin/cos
-  bool fold_sincos(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo,
+  bool fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo,
                    AliasAnalysis *AA);
 
   // __read_pipe/__write_pipe
@@ -113,7 +115,9 @@ class AMDGPULibCalls {
 protected:
   CallInst *CI;
 
-  bool isUnsafeMath(const CallInst *CI) const;
+  bool isUnsafeMath(const FPMathOperator *FPOp) const;
+
+  bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const;
 
   void replaceCall(Value *With) {
     CI->replaceAllUsesWith(With);
@@ -125,6 +129,7 @@ class AMDGPULibCalls {
 
   bool fold(CallInst *CI, AliasAnalysis *AA = nullptr);
 
+  void initFunction(const Function &F);
   void initNativeFuncs();
 
   // Replace a normal math function call with that native version
@@ -445,13 +450,18 @@ bool AMDGPULibCalls::parseFunctionName(const StringRef &FMangledName,
   return AMDGPULibFunc::parse(FMangledName, FInfo);
 }
 
-bool AMDGPULibCalls::isUnsafeMath(const CallInst *CI) const {
-  if (auto Op = dyn_cast<FPMathOperator>(CI))
-    if (Op->isFast())
-      return true;
-  const Function *F = CI->getParent()->getParent();
-  Attribute Attr = F->getFnAttribute("unsafe-fp-math");
-  return Attr.getValueAsBool();
+bool AMDGPULibCalls::isUnsafeMath(const FPMathOperator *FPOp) const {
+  return UnsafeFPMath || FPOp->isFast();
+}
+
+bool AMDGPULibCalls::canIncreasePrecisionOfConstantFold(
+    const FPMathOperator *FPOp) const {
+  // TODO: Refine to approxFunc or contract
+  return isUnsafeMath(FPOp);
+}
+
+void AMDGPULibCalls::initFunction(const Function &F) {
+  UnsafeFPMath = F.getFnAttribute("unsafe-fp-math").getValueAsBool();
 }
 
 bool AMDGPULibCalls::useNativeFunc(const StringRef F) const {
@@ -620,65 +630,61 @@ bool AMDGPULibCalls::fold(CallInst *CI, AliasAnalysis *AA) {
   if (TDOFold(CI, FInfo))
     return true;
 
-  // Under unsafe-math, evaluate calls if possible.
-  // According to Brian Sumner, we can do this for all f32 function calls
-  // using host's double function calls.
-  if (isUnsafeMath(CI) && evaluateCall(CI, FInfo))
-    return true;
+  if (FPMathOperator *FPOp = dyn_cast<FPMathOperator>(CI)) {
+    // Under unsafe-math, evaluate calls if possible.
+    // According to Brian Sumner, we can do this for all f32 function calls
+    // using host's double function calls.
+    if (canIncreasePrecisionOfConstantFold(FPOp) && evaluateCall(CI, FInfo))
+      return true;
 
-  // Copy fast flags from the original call.
-  if (const FPMathOperator *FPOp = dyn_cast<const FPMathOperator>(CI))
+    // Copy fast flags from the original call.
     B.setFastMathFlags(FPOp->getFastMathFlags());
 
-  // Specialized optimizations for each function call
-  switch (FInfo.getId()) {
-  case AMDGPULibFunc::EI_RECIP:
-    // skip vector function
-    assert ((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
-             FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
-            "recip must be an either native or half function");
-    return (getVecSize(FInfo) != 1) ? false : fold_recip(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_DIVIDE:
-    // skip vector function
-    assert ((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
-             FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
-            "divide must be an either native or half function");
-    return (getVecSize(FInfo) != 1) ? false : fold_divide(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_POW:
-  case AMDGPULibFunc::EI_POWR:
-  case AMDGPULibFunc::EI_POWN:
-    return fold_pow(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_ROOTN:
-    // skip vector function
-    return (getVecSize(FInfo) != 1) ? false : fold_rootn(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_FMA:
-  case AMDGPULibFunc::EI_MAD:
-  case AMDGPULibFunc::EI_NFMA:
-    // skip vector function
-    return (getVecSize(FInfo) != 1) ? false : fold_fma_mad(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_SQRT:
-    return isUnsafeMath(CI) && fold_sqrt(CI, B, FInfo);
-  case AMDGPULibFunc::EI_COS:
-  case AMDGPULibFunc::EI_SIN:
-    if ((getArgType(FInfo) == AMDGPULibFunc::F32 ||
-         getArgType(FInfo) == AMDGPULibFunc::F64)
-        && (FInfo.getPrefix() == AMDGPULibFunc::NOPFX))
-      return fold_sincos(CI, B, FInfo, AA);
-
-    break;
-  case AMDGPULibFunc::EI_READ_PIPE_2:
-  case AMDGPULibFunc::EI_READ_PIPE_4:
-  case AMDGPULibFunc::EI_WRITE_PIPE_2:
-  case AMDGPULibFunc::EI_WRITE_PIPE_4:
-    return fold_read_write_pipe(CI, B, FInfo);
-
-  default:
-    break;
+    // Specialized optimizations for each function call
+    switch (FInfo.getId()) {
+    case AMDGPULibFunc::EI_POW:
+    case AMDGPULibFunc::EI_POWR:
+    case AMDGPULibFunc::EI_POWN:
+      return fold_pow(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_ROOTN:
+      return fold_rootn(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_SQRT:
+      return fold_sqrt(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_COS:
+    case AMDGPULibFunc::EI_SIN:
+      return fold_sincos(FPOp, B, FInfo, AA);
+    case AMDGPULibFunc::EI_RECIP:
+      // skip vector function
+      assert((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
+              FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
+             "recip must be an either native or half function");
+      return (getVecSize(FInfo) != 1) ? false : fold_recip(CI, B, FInfo);
+
+    case AMDGPULibFunc::EI_DIVIDE:
+      // skip vector function
+      assert((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
+              FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
+             "divide must be an either native or half function");
+      return (getVecSize(FInfo) != 1) ? false : fold_divide(CI, B, FInfo);
+    case AMDGPULibFunc::EI_FMA:
+    case AMDGPULibFunc::EI_MAD:
+    case AMDGPULibFunc::EI_NFMA:
+      // skip vector function
+      return (getVecSize(FInfo) != 1) ? false : fold_fma_mad(CI, B, FInfo);
+    default:
+      break;
+    }
+  } else {
+    // Specialized optimizations for each function call
+    switch (FInfo.getId()) {
+    case AMDGPULibFunc::EI_READ_PIPE_2:
+    case AMDGPULibFunc::EI_READ_PIPE_4:
+    case AMDGPULibFunc::EI_WRITE_PIPE_2:
+    case AMDGPULibFunc::EI_WRITE_PIPE_4:
+      return fold_read_write_pipe(CI, B, FInfo);
+    default:
+      break;
+    }
   }
 
   return false;
@@ -796,7 +802,7 @@ static double log2(double V) {
 }
 }
 
-bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
                               const FuncInfo &FInfo) {
   assert((FInfo.getId() == AMDGPULibFunc::EI_POW ||
           FInfo.getId() == AMDGPULibFunc::EI_POWR ||
@@ -827,7 +833,7 @@ bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
   }
 
   // No unsafe math , no constant argument, do nothing
-  if (!isUnsafeMath(CI) && !CF && !CINT && !CZero)
+  if (!isUnsafeMath(FPOp) && !CF && !CINT && !CZero)
     return false;
 
   // 0x1111111 means that we don't do anything for this call.
@@ -885,7 +891,7 @@ bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
     }
   }
 
-  if (!isUnsafeMath(CI))
+  if (!isUnsafeMath(FPOp))
     return false;
 
   // Unsafe Math optimization
@@ -1079,10 +1085,14 @@ bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
   return true;
 }
 
-bool AMDGPULibCalls::fold_rootn(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
                                 const FuncInfo &FInfo) {
-  Value *opr0 = CI->getArgOperand(0);
-  Value *opr1 = CI->getArgOperand(1);
+  // skip vector function
+  if (getVecSize(FInfo) != 1)
+    return false;
+
+  Value *opr0 = FPOp->getOperand(0);
+  Value *opr1 = FPOp->getOperand(1);
 
   ConstantInt *CINT = dyn_cast<ConstantInt>(opr1);
   if (!CINT) {
@@ -1188,8 +1198,11 @@ FunctionCallee AMDGPULibCalls::getNativeFunction(Module *M,
 }
 
 // fold sqrt -> native_sqrt (x)
-bool AMDGPULibCalls::fold_sqrt(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
                                const FuncInfo &FInfo) {
+  if (!isUnsafeMath(FPOp))
+    return false;
+
   if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
       (FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
     if (FunctionCallee FPExpr = getNativeFunction(
@@ -1206,10 +1219,16 @@ bool AMDGPULibCalls::fold_sqrt(CallInst *CI, IRBuilder<> &B,
 }
 
 // fold sin, cos -> sincos.
-bool AMDGPULibCalls::fold_sincos(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
                                  const FuncInfo &fInfo, AliasAnalysis *AA) {
   assert(fInfo.getId() == AMDGPULibFunc::EI_SIN ||
          fInfo.getId() == AMDGPULibFunc::EI_COS);
+
+  if ((getArgType(fInfo) != AMDGPULibFunc::F32 &&
+       getArgType(fInfo) != AMDGPULibFunc::F64) ||
+      fInfo.getPrefix() != AMDGPULibFunc::NOPFX)
+    return false;
+
   bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;
 
   Value *CArgVal = CI->getArgOperand(0);
@@ -1651,6 +1670,8 @@ bool AMDGPUSimplifyLibCalls::runOnFunction(Function &F) {
   if (skipFunction(F))
     return false;
 
+  Simplifier.initFunction(F);
+
   bool Changed = false;
   auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
 
@@ -1675,6 +1696,7 @@ PreservedAnalyses AMDGPUSimplifyLibCallsPass::run(Function &F,
                                                   FunctionAnalysisManager &AM) {
   AMDGPULibCalls Simplifier(&TM);
   Simplifier.initNativeFuncs();
+  Simplifier.initFunction(F);
 
   bool Changed = false;
   auto AA = &AM.getResult<AAManager>(F);
@@ -1701,6 +1723,8 @@ bool AMDGPUUseNativeCalls::runOnFunction(Function &F) {
   if (skipFunction(F) || UseNative.empty())
     return false;
 
+  Simplifier.initFunction(F);
+
   bool Changed = false;
   for (auto &BB : F) {
     for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ) {
@@ -1721,6 +1745,7 @@ PreservedAnalyses AMDGPUUseNativeCallsPass::run(Function &F,
 
   AMDGPULibCalls Simplifier;
   Simplifier.initNativeFuncs();
+  Simplifier.initFunction(F);
 
   bool Changed = false;
   for (auto &BB : F) {


        


More information about the llvm-commits mailing list