[llvm] cd3f48d - [NFC][DXIL] move replace/erase in DXIL intrinsic expansion to caller (#104626)

via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 16 17:11:21 PDT 2024


Author: Greg Roth
Date: 2024-08-16T17:11:16-07:00
New Revision: cd3f48df82dae19493da838afd1c0aaf98c737eb

URL: https://github.com/llvm/llvm-project/commit/cd3f48df82dae19493da838afd1c0aaf98c737eb
DIFF: https://github.com/llvm/llvm-project/commit/cd3f48df82dae19493da838afd1c0aaf98c737eb.diff

LOG: [NFC][DXIL] move replace/erase in DXIL intrinsic expansion to caller (#104626)

All expansions end with replacing the previous intrinsic with the new
expansion and erasing the old one. By moving this operation to the
caller, these expansion functions can be called in more contexts and a
small amount of duplicated code is consolidated.

Pre-req for #88056

Added: 
    

Modified: 
    llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e63633b8a1e1a..2c481d15be5bd 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -51,7 +51,7 @@ static bool isIntrinsicExpansion(Function &F) {
   return false;
 }
 
-static bool expandAbs(CallInst *Orig) {
+static Value *expandAbs(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -64,14 +64,11 @@ static bool expandAbs(CallInst *Orig) {
                              ConstantInt::get(EltTy, 0))
                        : ConstantInt::get(EltTy, 0);
   auto *V = Builder.CreateSub(Zero, X);
-  auto *MaxCall =
-      Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max");
-  Orig->replaceAllUsesWith(MaxCall);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr,
+                                 "dx.max");
 }
 
-static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
+static Value *expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
   assert(DotIntrinsic == Intrinsic::dx_sdot ||
          DotIntrinsic == Intrinsic::dx_udot);
   Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
@@ -97,12 +94,10 @@ static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
                                      ArrayRef<Value *>{Elt0, Elt1, Result},
                                      nullptr, "dx.mad");
   }
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Result;
 }
 
-static bool expandExpIntrinsic(CallInst *Orig) {
+static Value *expandExpIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -119,23 +114,21 @@ static bool expandExpIntrinsic(CallInst *Orig) {
       Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
   Exp2Call->setTailCall(Orig->isTailCall());
   Exp2Call->setAttributes(Orig->getAttributes());
-  Orig->replaceAllUsesWith(Exp2Call);
-  Orig->eraseFromParent();
-  return true;
+  return Exp2Call;
 }
 
-static bool expandAnyIntrinsic(CallInst *Orig) {
+static Value *expandAnyIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
 
+  Value *Result = nullptr;
   if (!Ty->isVectorTy()) {
-    Value *Cond = EltTy->isFloatingPointTy()
-                      ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
-                      : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
-    Orig->replaceAllUsesWith(Cond);
+    Result = EltTy->isFloatingPointTy()
+                 ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
+                 : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
   } else {
     auto *XVec = dyn_cast<FixedVectorType>(Ty);
     Value *Cond =
@@ -148,18 +141,16 @@ static bool expandAnyIntrinsic(CallInst *Orig) {
                   X, ConstantVector::getSplat(
                          ElementCount::getFixed(XVec->getNumElements()),
                          ConstantInt::get(EltTy, 0)));
-    Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
+    Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
     for (unsigned I = 1; I < XVec->getNumElements(); I++) {
       Value *Elt = Builder.CreateExtractElement(Cond, I);
       Result = Builder.CreateOr(Result, Elt);
     }
-    Orig->replaceAllUsesWith(Result);
   }
-  Orig->eraseFromParent();
-  return true;
+  return Result;
 }
 
-static bool expandLengthIntrinsic(CallInst *Orig) {
+static Value *expandLengthIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -182,15 +173,11 @@ static bool expandLengthIntrinsic(CallInst *Orig) {
     Value *Mul = Builder.CreateFMul(Elt, Elt);
     Sum = Builder.CreateFAdd(Sum, Mul);
   }
-  Value *Result = Builder.CreateIntrinsic(
-      EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum}, nullptr, "elt.sqrt");
-
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateIntrinsic(EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum},
+                                 nullptr, "elt.sqrt");
 }
 
-static bool expandLerpIntrinsic(CallInst *Orig) {
+static Value *expandLerpIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
   Value *S = Orig->getOperand(2);
@@ -198,14 +185,11 @@ static bool expandLerpIntrinsic(CallInst *Orig) {
   Builder.SetInsertPoint(Orig);
   auto *V = Builder.CreateFSub(Y, X);
   V = Builder.CreateFMul(S, V);
-  auto *Result = Builder.CreateFAdd(X, V, "dx.lerp");
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateFAdd(X, V, "dx.lerp");
 }
 
-static bool expandLogIntrinsic(CallInst *Orig,
-                               float LogConstVal = numbers::ln2f) {
+static Value *expandLogIntrinsic(CallInst *Orig,
+                                 float LogConstVal = numbers::ln2f) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -221,16 +205,13 @@ static bool expandLogIntrinsic(CallInst *Orig,
       Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
   Log2Call->setTailCall(Orig->isTailCall());
   Log2Call->setAttributes(Orig->getAttributes());
-  auto *Result = Builder.CreateFMul(Ln2Const, Log2Call);
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateFMul(Ln2Const, Log2Call);
 }
-static bool expandLog10Intrinsic(CallInst *Orig) {
+static Value *expandLog10Intrinsic(CallInst *Orig) {
   return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
 }
 
-static bool expandNormalizeIntrinsic(CallInst *Orig) {
+static Value *expandNormalizeIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Type *Ty = Orig->getType();
   Type *EltTy = Ty->getScalarType();
@@ -245,11 +226,7 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
         report_fatal_error(Twine("Invalid input scalar: length is zero"),
                            /* gen_crash_diag=*/false);
     }
-    Value *Result = Builder.CreateFDiv(X, X);
-
-    Orig->replaceAllUsesWith(Result);
-    Orig->eraseFromParent();
-    return true;
+    return Builder.CreateFDiv(X, X);
   }
 
   unsigned XVecSize = XVec->getNumElements();
@@ -291,14 +268,10 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
                                                 nullptr, "dx.rsqrt");
 
   Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
-  Value *Result = Builder.CreateFMul(X, MultiplicandVec);
-
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateFMul(X, MultiplicandVec);
 }
 
-static bool expandPowIntrinsic(CallInst *Orig) {
+static Value *expandPowIntrinsic(CallInst *Orig) {
 
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
@@ -313,9 +286,7 @@ static bool expandPowIntrinsic(CallInst *Orig) {
       Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
   Exp2Call->setTailCall(Orig->isTailCall());
   Exp2Call->setAttributes(Orig->getAttributes());
-  Orig->replaceAllUsesWith(Exp2Call);
-  Orig->eraseFromParent();
-  return true;
+  return Exp2Call;
 }
 
 static Intrinsic::ID getMaxForClamp(Type *ElemTy,
@@ -344,7 +315,8 @@ static Intrinsic::ID getMinForClamp(Type *ElemTy,
   return Intrinsic::minnum;
 }
 
-static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
+static Value *expandClampIntrinsic(CallInst *Orig,
+                                   Intrinsic::ID ClampIntrinsic) {
   Value *X = Orig->getOperand(0);
   Value *Min = Orig->getOperand(1);
   Value *Max = Orig->getOperand(2);
@@ -353,41 +325,54 @@ static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
   Builder.SetInsertPoint(Orig);
   auto *MaxCall = Builder.CreateIntrinsic(
       Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
-  auto *MinCall =
-      Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
-                              {MaxCall, Max}, nullptr, "dx.min");
-
-  Orig->replaceAllUsesWith(MinCall);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
+                                 {MaxCall, Max}, nullptr, "dx.min");
 }
 
 static bool expandIntrinsic(Function &F, CallInst *Orig) {
+  Value *Result = nullptr;
   switch (F.getIntrinsicID()) {
   case Intrinsic::abs:
-    return expandAbs(Orig);
+    Result = expandAbs(Orig);
+    break;
   case Intrinsic::exp:
-    return expandExpIntrinsic(Orig);
+    Result = expandExpIntrinsic(Orig);
+    break;
   case Intrinsic::log:
-    return expandLogIntrinsic(Orig);
+    Result = expandLogIntrinsic(Orig);
+    break;
   case Intrinsic::log10:
-    return expandLog10Intrinsic(Orig);
+    Result = expandLog10Intrinsic(Orig);
+    break;
   case Intrinsic::pow:
-    return expandPowIntrinsic(Orig);
+    Result = expandPowIntrinsic(Orig);
+    break;
   case Intrinsic::dx_any:
-    return expandAnyIntrinsic(Orig);
+    Result = expandAnyIntrinsic(Orig);
+    break;
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_clamp:
-    return expandClampIntrinsic(Orig, F.getIntrinsicID());
+    Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
+    break;
   case Intrinsic::dx_lerp:
-    return expandLerpIntrinsic(Orig);
+    Result = expandLerpIntrinsic(Orig);
+    break;
   case Intrinsic::dx_length:
-    return expandLengthIntrinsic(Orig);
+    Result = expandLengthIntrinsic(Orig);
+    break;
   case Intrinsic::dx_normalize:
-    return expandNormalizeIntrinsic(Orig);
+    Result = expandNormalizeIntrinsic(Orig);
+    break;
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
-    return expandIntegerDot(Orig, F.getIntrinsicID());
+    Result = expandIntegerDot(Orig, F.getIntrinsicID());
+    break;
+  }
+
+  if (Result) {
+    Orig->replaceAllUsesWith(Result);
+    Orig->eraseFromParent();
+    return true;
   }
   return false;
 }


        


More information about the llvm-commits mailing list