[llvm] [NFC][DXIL] move replace/erase in DXIL intrinsic expansion to caller (PR #104626)

via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 16 11:22:12 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: Greg Roth (pow2clk)

<details>
<summary>Changes</summary>

All expansions end by replacing the previous intrinsic with the new expansion and erasing the old one. By moving this operation to the caller, these expansion functions can be called in more contexts and a small amount of duplicated code is consolidated.

Pre-req for #<!-- -->88056

---
Full diff: https://github.com/llvm/llvm-project/pull/104626.diff


1 Files Affected:

- (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+60-75) 


``````````diff
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e63633b8a1e1ab..a4cb1c3d575473 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -51,7 +51,7 @@ static bool isIntrinsicExpansion(Function &F) {
   return false;
 }
 
-static bool expandAbs(CallInst *Orig) {
+static Value *expandAbs(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -66,12 +66,10 @@ static bool expandAbs(CallInst *Orig) {
   auto *V = Builder.CreateSub(Zero, X);
   auto *MaxCall =
       Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max");
-  Orig->replaceAllUsesWith(MaxCall);
-  Orig->eraseFromParent();
-  return true;
+  return MaxCall;
 }
 
-static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
+static Value *expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
   assert(DotIntrinsic == Intrinsic::dx_sdot ||
          DotIntrinsic == Intrinsic::dx_udot);
   Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
@@ -97,12 +95,10 @@ static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
                                      ArrayRef<Value *>{Elt0, Elt1, Result},
                                      nullptr, "dx.mad");
   }
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Result;
 }
 
-static bool expandExpIntrinsic(CallInst *Orig) {
+static Value *expandExpIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -119,23 +115,21 @@ static bool expandExpIntrinsic(CallInst *Orig) {
       Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
   Exp2Call->setTailCall(Orig->isTailCall());
   Exp2Call->setAttributes(Orig->getAttributes());
-  Orig->replaceAllUsesWith(Exp2Call);
-  Orig->eraseFromParent();
-  return true;
+  return Exp2Call;
 }
 
-static bool expandAnyIntrinsic(CallInst *Orig) {
+static Value *expandAnyIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
 
+  Value *Result = nullptr;
   if (!Ty->isVectorTy()) {
-    Value *Cond = EltTy->isFloatingPointTy()
-                      ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
-                      : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
-    Orig->replaceAllUsesWith(Cond);
+    Result = EltTy->isFloatingPointTy()
+                 ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
+                 : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
   } else {
     auto *XVec = dyn_cast<FixedVectorType>(Ty);
     Value *Cond =
@@ -148,18 +142,16 @@ static bool expandAnyIntrinsic(CallInst *Orig) {
                   X, ConstantVector::getSplat(
                          ElementCount::getFixed(XVec->getNumElements()),
                          ConstantInt::get(EltTy, 0)));
-    Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
+    Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
     for (unsigned I = 1; I < XVec->getNumElements(); I++) {
       Value *Elt = Builder.CreateExtractElement(Cond, I);
       Result = Builder.CreateOr(Result, Elt);
     }
-    Orig->replaceAllUsesWith(Result);
   }
-  Orig->eraseFromParent();
-  return true;
+  return Result;
 }
 
-static bool expandLengthIntrinsic(CallInst *Orig) {
+static Value *expandLengthIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -182,15 +174,11 @@ static bool expandLengthIntrinsic(CallInst *Orig) {
     Value *Mul = Builder.CreateFMul(Elt, Elt);
     Sum = Builder.CreateFAdd(Sum, Mul);
   }
-  Value *Result = Builder.CreateIntrinsic(
-      EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum}, nullptr, "elt.sqrt");
-
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateIntrinsic(EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum},
+                                 nullptr, "elt.sqrt");
 }
 
-static bool expandLerpIntrinsic(CallInst *Orig) {
+static Value *expandLerpIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
   Value *S = Orig->getOperand(2);
@@ -198,14 +186,11 @@ static bool expandLerpIntrinsic(CallInst *Orig) {
   Builder.SetInsertPoint(Orig);
   auto *V = Builder.CreateFSub(Y, X);
   V = Builder.CreateFMul(S, V);
-  auto *Result = Builder.CreateFAdd(X, V, "dx.lerp");
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateFAdd(X, V, "dx.lerp");
 }
 
-static bool expandLogIntrinsic(CallInst *Orig,
-                               float LogConstVal = numbers::ln2f) {
+static Value *expandLogIntrinsic(CallInst *Orig,
+                                 float LogConstVal = numbers::ln2f) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -221,16 +206,13 @@ static bool expandLogIntrinsic(CallInst *Orig,
       Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
   Log2Call->setTailCall(Orig->isTailCall());
   Log2Call->setAttributes(Orig->getAttributes());
-  auto *Result = Builder.CreateFMul(Ln2Const, Log2Call);
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateFMul(Ln2Const, Log2Call);
 }
-static bool expandLog10Intrinsic(CallInst *Orig) {
+static Value *expandLog10Intrinsic(CallInst *Orig) {
   return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
 }
 
-static bool expandNormalizeIntrinsic(CallInst *Orig) {
+static Value *expandNormalizeIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Type *Ty = Orig->getType();
   Type *EltTy = Ty->getScalarType();
@@ -245,11 +227,7 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
         report_fatal_error(Twine("Invalid input scalar: length is zero"),
                            /* gen_crash_diag=*/false);
     }
-    Value *Result = Builder.CreateFDiv(X, X);
-
-    Orig->replaceAllUsesWith(Result);
-    Orig->eraseFromParent();
-    return true;
+    return Builder.CreateFDiv(X, X);
   }
 
   unsigned XVecSize = XVec->getNumElements();
@@ -291,14 +269,10 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
                                                 nullptr, "dx.rsqrt");
 
   Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
-  Value *Result = Builder.CreateFMul(X, MultiplicandVec);
-
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateFMul(X, MultiplicandVec);
 }
 
-static bool expandPowIntrinsic(CallInst *Orig) {
+static Value *expandPowIntrinsic(CallInst *Orig) {
 
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
@@ -313,9 +287,7 @@ static bool expandPowIntrinsic(CallInst *Orig) {
       Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
   Exp2Call->setTailCall(Orig->isTailCall());
   Exp2Call->setAttributes(Orig->getAttributes());
-  Orig->replaceAllUsesWith(Exp2Call);
-  Orig->eraseFromParent();
-  return true;
+  return Exp2Call;
 }
 
 static Intrinsic::ID getMaxForClamp(Type *ElemTy,
@@ -344,7 +316,8 @@ static Intrinsic::ID getMinForClamp(Type *ElemTy,
   return Intrinsic::minnum;
 }
 
-static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
+static Value *expandClampIntrinsic(CallInst *Orig,
+                                   Intrinsic::ID ClampIntrinsic) {
   Value *X = Orig->getOperand(0);
   Value *Min = Orig->getOperand(1);
   Value *Max = Orig->getOperand(2);
@@ -353,43 +326,55 @@ static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
   Builder.SetInsertPoint(Orig);
   auto *MaxCall = Builder.CreateIntrinsic(
       Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
-  auto *MinCall =
-      Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
-                              {MaxCall, Max}, nullptr, "dx.min");
-
-  Orig->replaceAllUsesWith(MinCall);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
+                                 {MaxCall, Max}, nullptr, "dx.min");
 }
 
 static bool expandIntrinsic(Function &F, CallInst *Orig) {
+  Value *Result = nullptr;
   switch (F.getIntrinsicID()) {
   case Intrinsic::abs:
-    return expandAbs(Orig);
+    Result = expandAbs(Orig);
+    break;
   case Intrinsic::exp:
-    return expandExpIntrinsic(Orig);
+    Result = expandExpIntrinsic(Orig);
+    break;
   case Intrinsic::log:
-    return expandLogIntrinsic(Orig);
+    Result = expandLogIntrinsic(Orig);
+    break;
   case Intrinsic::log10:
-    return expandLog10Intrinsic(Orig);
+    Result = expandLog10Intrinsic(Orig);
+    break;
   case Intrinsic::pow:
-    return expandPowIntrinsic(Orig);
+    Result = expandPowIntrinsic(Orig);
+    break;
   case Intrinsic::dx_any:
-    return expandAnyIntrinsic(Orig);
+    Result = expandAnyIntrinsic(Orig);
+    break;
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_clamp:
-    return expandClampIntrinsic(Orig, F.getIntrinsicID());
+    Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
+    break;
   case Intrinsic::dx_lerp:
-    return expandLerpIntrinsic(Orig);
+    Result = expandLerpIntrinsic(Orig);
+    break;
   case Intrinsic::dx_length:
-    return expandLengthIntrinsic(Orig);
+    Result = expandLengthIntrinsic(Orig);
+    break;
   case Intrinsic::dx_normalize:
-    return expandNormalizeIntrinsic(Orig);
+    Result = expandNormalizeIntrinsic(Orig);
+    break;
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
-    return expandIntegerDot(Orig, F.getIntrinsicID());
+    Result = expandIntegerDot(Orig, F.getIntrinsicID());
+    break;
   }
-  return false;
+
+  if (Result) {
+    Orig->replaceAllUsesWith(Result);
+    Orig->eraseFromParent();
+  }
+  return !!Result;
 }
 
 static bool expansionIntrinsics(Module &M) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/104626


More information about the llvm-commits mailing list