[llvm] c2c22c6 - AMDGPU: Don't store current instruction in AMDGPULibCalls member

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 31 08:44:27 PDT 2023


Author: Matt Arsenault
Date: 2023-07-31T11:44:21-04:00
New Revision: c2c22c6c95ed9543dc55559f8921ce453eabf6b3

URL: https://github.com/llvm/llvm-project/commit/c2c22c6c95ed9543dc55559f8921ce453eabf6b3
DIFF: https://github.com/llvm/llvm-project/commit/c2c22c6c95ed9543dc55559f8921ce453eabf6b3.diff

LOG: AMDGPU: Don't store current instruction in AMDGPULibCalls member

The CallInst *CI member was confusing global state, and most of the
time it was shadowed by a function parameter of the same name.

https://reviews.llvm.org/D156680

Added: 
    

Modified: 
    llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
index a9dc094c2cfafb..20395abba2acda 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
@@ -113,15 +113,17 @@ class AMDGPULibCalls {
   FunctionCallee getNativeFunction(Module *M, const FuncInfo &FInfo);
 
 protected:
-  CallInst *CI;
-
   bool isUnsafeMath(const FPMathOperator *FPOp) const;
 
   bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const;
 
-  void replaceCall(Value *With) {
-    CI->replaceAllUsesWith(With);
-    CI->eraseFromParent();
+  static void replaceCall(Instruction *I, Value *With) {
+    I->replaceAllUsesWith(With);
+    I->eraseFromParent();
+  }
+
+  static void replaceCall(FPMathOperator *I, Value *With) {
+    replaceCall(cast<Instruction>(I), With);
   }
 
 public:
@@ -501,7 +503,7 @@ bool AMDGPULibCalls::sincosUseNative(CallInst *aCI, const FuncInfo &FInfo) {
       DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
                                           << " with native version of sin/cos");
 
-      replaceCall(sinval);
+      replaceCall(aCI, sinval);
       return true;
     }
   }
@@ -509,7 +511,6 @@ bool AMDGPULibCalls::sincosUseNative(CallInst *aCI, const FuncInfo &FInfo) {
 }
 
 bool AMDGPULibCalls::useNative(CallInst *aCI) {
-  CI = aCI;
   Function *Callee = aCI->getCalledFunction();
   if (!Callee || aCI->isNoBuiltin())
     return false;
@@ -601,7 +602,6 @@ bool AMDGPULibCalls::fold_read_write_pipe(CallInst *CI, IRBuilder<> &B,
 
 // This function returns false if no change; return true otherwise.
 bool AMDGPULibCalls::fold(CallInst *CI, AliasAnalysis *AA) {
-  this->CI = CI;
   Function *Callee = CI->getCalledFunction();
   // Ignore indirect calls.
   if (!Callee || CI->isNoBuiltin())
@@ -733,7 +733,7 @@ bool AMDGPULibCalls::TDOFold(CallInst *CI, const FuncInfo &FInfo) {
         nval = ConstantDataVector::get(context, tmp);
       }
       LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
-      replaceCall(nval);
+      replaceCall(CI, nval);
       return true;
     }
   } else {
@@ -743,7 +743,7 @@ bool AMDGPULibCalls::TDOFold(CallInst *CI, const FuncInfo &FInfo) {
         if (CF->isExactlyValue(tr[i].input)) {
           Value *nval = ConstantFP::get(CF->getType(), tr[i].result);
           LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
-          replaceCall(nval);
+          replaceCall(CI, nval);
           return true;
         }
       }
@@ -765,7 +765,7 @@ bool AMDGPULibCalls::fold_recip(CallInst *CI, IRBuilder<> &B,
                                opr0,
                                "recip2div");
     LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
-    replaceCall(nval);
+    replaceCall(CI, nval);
     return true;
   }
   return false;
@@ -786,7 +786,7 @@ bool AMDGPULibCalls::fold_divide(CallInst *CI, IRBuilder<> &B,
     Value *nval1 = B.CreateFDiv(ConstantFP::get(opr1->getType(), 1.0),
                                 opr1, "__div2recip");
     Value *nval  = B.CreateFMul(opr0, nval1, "__div2mul");
-    replaceCall(nval);
+    replaceCall(CI, nval);
     return true;
   }
   return false;
@@ -813,8 +813,8 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
   ConstantFP *CF;
   ConstantInt *CINT;
   Type *eltType;
-  Value *opr0 = CI->getArgOperand(0);
-  Value *opr1 = CI->getArgOperand(1);
+  Value *opr0 = FPOp->getOperand(0);
+  Value *opr1 = FPOp->getOperand(1);
   ConstantAggregateZero *CZero = dyn_cast<ConstantAggregateZero>(opr1);
 
   if (getVecSize(FInfo) == 1) {
@@ -841,37 +841,37 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
 
   if ((CF && CF->isZero()) || (CINT && ci_opr1 == 0) || CZero) {
     //  pow/powr/pown(x, 0) == 1
-    LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> 1\n");
+    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1\n");
     Constant *cnval = ConstantFP::get(eltType, 1.0);
     if (getVecSize(FInfo) > 1) {
       cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
     }
-    replaceCall(cnval);
+    replaceCall(FPOp, cnval);
     return true;
   }
   if ((CF && CF->isExactlyValue(1.0)) || (CINT && ci_opr1 == 1)) {
     // pow/powr/pown(x, 1.0) = x
-    LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << "\n");
-    replaceCall(opr0);
+    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n");
+    replaceCall(FPOp, opr0);
     return true;
   }
   if ((CF && CF->isExactlyValue(2.0)) || (CINT && ci_opr1 == 2)) {
     // pow/powr/pown(x, 2.0) = x*x
-    LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << " * " << *opr0
-                      << "\n");
+    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << " * "
+                      << *opr0 << "\n");
     Value *nval = B.CreateFMul(opr0, opr0, "__pow2");
-    replaceCall(nval);
+    replaceCall(FPOp, nval);
     return true;
   }
   if ((CF && CF->isExactlyValue(-1.0)) || (CINT && ci_opr1 == -1)) {
     // pow/powr/pown(x, -1.0) = 1.0/x
-    LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> 1 / " << *opr0 << "\n");
+    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1 / " << *opr0 << "\n");
     Constant *cnval = ConstantFP::get(eltType, 1.0);
     if (getVecSize(FInfo) > 1) {
       cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
     }
     Value *nval = B.CreateFDiv(cnval, opr0, "__powrecip");
-    replaceCall(nval);
+    replaceCall(FPOp, nval);
     return true;
   }
 
@@ -882,11 +882,11 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
             getFunction(M, AMDGPULibFunc(issqrt ? AMDGPULibFunc::EI_SQRT
                                                 : AMDGPULibFunc::EI_RSQRT,
                                          FInfo))) {
-      LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << FInfo.getName()
+      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << FInfo.getName()
                         << '(' << *opr0 << ")\n");
       Value *nval = CreateCallEx(B,FPExpr, opr0, issqrt ? "__pow2sqrt"
                                                         : "__pow2rsqrt");
-      replaceCall(nval);
+      replaceCall(FPOp, nval);
       return true;
     }
   }
@@ -939,10 +939,10 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
       }
       nval = B.CreateFDiv(cnval, nval, "__1powprod");
     }
-    LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> "
+    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
                       << ((ci_opr1 < 0) ? "1/prod(" : "prod(") << *opr0
                       << ")\n");
-    replaceCall(nval);
+    replaceCall(FPOp, nval);
     return true;
   }
 
@@ -1066,7 +1066,7 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
     if (const auto *vTy = dyn_cast<FixedVectorType>(rTy))
       nTy = FixedVectorType::get(nTyS, vTy);
     unsigned size = nTy->getScalarSizeInBits();
-    opr_n = CI->getArgOperand(1);
+    opr_n = FPOp->getOperand(1);
     if (opr_n->getType()->isIntegerTy())
       opr_n = B.CreateZExtOrBitCast(opr_n, nTy, "__ytou");
     else
@@ -1078,9 +1078,9 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
     nval = B.CreateBitCast(nval, opr0->getType());
   }
 
-  LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> "
+  LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
                     << "exp2(" << *opr1 << " * log2(" << *opr0 << "))\n");
-  replaceCall(nval);
+  replaceCall(FPOp, nval);
 
   return true;
 }
@@ -1100,43 +1100,44 @@ bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
   }
   int ci_opr1 = (int)CINT->getSExtValue();
   if (ci_opr1 == 1) {  // rootn(x, 1) = x
-    LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << "\n");
-    replaceCall(opr0);
+    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n");
+    replaceCall(FPOp, opr0);
     return true;
   }
-  if (ci_opr1 == 2) {  // rootn(x, 2) = sqrt(x)
-    Module *M = CI->getModule();
+
+  Module *M = B.GetInsertBlock()->getModule();
+  if (ci_opr1 == 2) { // rootn(x, 2) = sqrt(x)
     if (FunctionCallee FPExpr =
             getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
-      LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> sqrt(" << *opr0 << ")\n");
+      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> sqrt(" << *opr0
+                        << ")\n");
       Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2sqrt");
-      replaceCall(nval);
+      replaceCall(FPOp, nval);
       return true;
     }
   } else if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x)
-    Module *M = CI->getModule();
     if (FunctionCallee FPExpr =
             getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_CBRT, FInfo))) {
-      LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> cbrt(" << *opr0 << ")\n");
+      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> cbrt(" << *opr0
+                        << ")\n");
       Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2cbrt");
-      replaceCall(nval);
+      replaceCall(FPOp, nval);
       return true;
     }
   } else if (ci_opr1 == -1) { // rootn(x, -1) = 1.0/x
-    LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> 1.0 / " << *opr0 << "\n");
+    LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1.0 / " << *opr0 << "\n");
     Value *nval = B.CreateFDiv(ConstantFP::get(opr0->getType(), 1.0),
                                opr0,
                                "__rootn2div");
-    replaceCall(nval);
+    replaceCall(FPOp, nval);
     return true;
-  } else if (ci_opr1 == -2) {  // rootn(x, -2) = rsqrt(x)
-    Module *M = CI->getModule();
+  } else if (ci_opr1 == -2) { // rootn(x, -2) = rsqrt(x)
     if (FunctionCallee FPExpr =
             getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_RSQRT, FInfo))) {
-      LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> rsqrt(" << *opr0
+      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> rsqrt(" << *opr0
                         << ")\n");
       Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2rsqrt");
-      replaceCall(nval);
+      replaceCall(FPOp, nval);
       return true;
     }
   }
@@ -1154,7 +1155,7 @@ bool AMDGPULibCalls::fold_fma_mad(CallInst *CI, IRBuilder<> &B,
   if ((CF0 && CF0->isZero()) || (CF1 && CF1->isZero())) {
     // fma/mad(a, b, c) = c if a=0 || b=0
     LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr2 << "\n");
-    replaceCall(opr2);
+    replaceCall(CI, opr2);
     return true;
   }
   if (CF0 && CF0->isExactlyValue(1.0f)) {
@@ -1162,7 +1163,7 @@ bool AMDGPULibCalls::fold_fma_mad(CallInst *CI, IRBuilder<> &B,
     LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr1 << " + " << *opr2
                       << "\n");
     Value *nval = B.CreateFAdd(opr1, opr2, "fmaadd");
-    replaceCall(nval);
+    replaceCall(CI, nval);
     return true;
   }
   if (CF1 && CF1->isExactlyValue(1.0f)) {
@@ -1170,7 +1171,7 @@ bool AMDGPULibCalls::fold_fma_mad(CallInst *CI, IRBuilder<> &B,
     LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << " + " << *opr2
                       << "\n");
     Value *nval = B.CreateFAdd(opr0, opr2, "fmaadd");
-    replaceCall(nval);
+    replaceCall(CI, nval);
     return true;
   }
   if (ConstantFP *CF = dyn_cast<ConstantFP>(opr2)) {
@@ -1179,7 +1180,7 @@ bool AMDGPULibCalls::fold_fma_mad(CallInst *CI, IRBuilder<> &B,
       LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << " * "
                         << *opr1 << "\n");
       Value *nval = B.CreateFMul(opr0, opr1, "fmamul");
-      replaceCall(nval);
+      replaceCall(CI, nval);
       return true;
     }
   }
@@ -1205,13 +1206,15 @@ bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
 
   if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
       (FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
+    Module *M = B.GetInsertBlock()->getModule();
+
     if (FunctionCallee FPExpr = getNativeFunction(
-            CI->getModule(), AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
-      Value *opr0 = CI->getArgOperand(0);
-      LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> "
+            M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
+      Value *opr0 = FPOp->getOperand(0);
+      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
                         << "sqrt(" << *opr0 << ")\n");
       Value *nval = CreateCallEx(B,FPExpr, opr0, "__sqrt");
-      replaceCall(nval);
+      replaceCall(FPOp, nval);
       return true;
     }
   }
@@ -1231,7 +1234,8 @@ bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
 
   bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;
 
-  Value *CArgVal = CI->getArgOperand(0);
+  Value *CArgVal = FPOp->getOperand(0);
+  CallInst *CI = cast<CallInst>(FPOp);
   BasicBlock * const CBB = CI->getParent();
 
   int const MaxScan = 30;
@@ -1247,7 +1251,7 @@ bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
         CArgVal->replaceAllUsesWith(AvailableVal);
         if (CArgVal->getNumUses() == 0)
           LI->eraseFromParent();
-        CArgVal = CI->getArgOperand(0);
+        CArgVal = FPOp->getOperand(0);
       }
     }
   }
@@ -1617,12 +1621,12 @@ bool AMDGPULibCalls::evaluateCall(CallInst *aCI, const FuncInfo &FInfo) {
     }
   }
 
-  LLVMContext &context = CI->getParent()->getParent()->getContext();
+  LLVMContext &context = aCI->getContext();
   Constant *nval0, *nval1;
   if (FuncVecSize == 1) {
-    nval0 = ConstantFP::get(CI->getType(), DVal0[0]);
+    nval0 = ConstantFP::get(aCI->getType(), DVal0[0]);
     if (hasTwoResults)
-      nval1 = ConstantFP::get(CI->getType(), DVal1[0]);
+      nval1 = ConstantFP::get(aCI->getType(), DVal1[0]);
   } else {
     if (getArgType(FInfo) == AMDGPULibFunc::F32) {
       SmallVector <float, 0> FVal0, FVal1;
@@ -1653,7 +1657,7 @@ bool AMDGPULibCalls::evaluateCall(CallInst *aCI, const FuncInfo &FInfo) {
     new StoreInst(nval1, aCI->getArgOperand(1), aCI);
   }
 
-  replaceCall(nval0);
+  replaceCall(aCI, nval0);
   return true;
 }
 


        


More information about the llvm-commits mailing list