[llvm] [NFC][DXIL] move replace/erase in DXIL intrinsic expansion to caller (PR #104626)

Greg Roth via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 16 11:21:35 PDT 2024


https://github.com/pow2clk created https://github.com/llvm/llvm-project/pull/104626

All expansions end with replacing the previous intrinsic with the new expansion and erasing the old one. By moving this operation to the caller, these expansion functions can be called in more contexts and a small amount of duplicated code is consolidated.

Pre-req for #88056

>From b24bf27c30b957e9a041c538fee882b932c9b360 Mon Sep 17 00:00:00 2001
From: Greg Roth <grroth at microsoft.com>
Date: Wed, 14 Aug 2024 12:42:33 -0600
Subject: [PATCH] [NFC][DXIL] move replace/erase in DXIL intrinsic expansion to
 caller

All expansions end with replacing the previous intrinsic with the new
expansion and erasing the old one. By moving this operation to the
caller, these expansion functions can be called in more contexts
and a small amount of duplicated code is consolidated.

Pre-req for #88056
---
 .../Target/DirectX/DXILIntrinsicExpansion.cpp | 135 ++++++++----------
 1 file changed, 60 insertions(+), 75 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e63633b8a1e1ab..a4cb1c3d575473 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -51,7 +51,7 @@ static bool isIntrinsicExpansion(Function &F) {
   return false;
 }
 
-static bool expandAbs(CallInst *Orig) {
+static Value *expandAbs(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -66,12 +66,10 @@ static bool expandAbs(CallInst *Orig) {
   auto *V = Builder.CreateSub(Zero, X);
   auto *MaxCall =
       Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max");
-  Orig->replaceAllUsesWith(MaxCall);
-  Orig->eraseFromParent();
-  return true;
+  return MaxCall;
 }
 
-static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
+static Value *expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
   assert(DotIntrinsic == Intrinsic::dx_sdot ||
          DotIntrinsic == Intrinsic::dx_udot);
   Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
@@ -97,12 +95,10 @@ static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
                                      ArrayRef<Value *>{Elt0, Elt1, Result},
                                      nullptr, "dx.mad");
   }
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Result;
 }
 
-static bool expandExpIntrinsic(CallInst *Orig) {
+static Value *expandExpIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -119,23 +115,21 @@ static bool expandExpIntrinsic(CallInst *Orig) {
       Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
   Exp2Call->setTailCall(Orig->isTailCall());
   Exp2Call->setAttributes(Orig->getAttributes());
-  Orig->replaceAllUsesWith(Exp2Call);
-  Orig->eraseFromParent();
-  return true;
+  return Exp2Call;
 }
 
-static bool expandAnyIntrinsic(CallInst *Orig) {
+static Value *expandAnyIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
 
+  Value *Result = nullptr;
   if (!Ty->isVectorTy()) {
-    Value *Cond = EltTy->isFloatingPointTy()
-                      ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
-                      : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
-    Orig->replaceAllUsesWith(Cond);
+    Result = EltTy->isFloatingPointTy()
+                 ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
+                 : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
   } else {
     auto *XVec = dyn_cast<FixedVectorType>(Ty);
     Value *Cond =
@@ -148,18 +142,16 @@ static bool expandAnyIntrinsic(CallInst *Orig) {
                   X, ConstantVector::getSplat(
                          ElementCount::getFixed(XVec->getNumElements()),
                          ConstantInt::get(EltTy, 0)));
-    Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
+    Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
     for (unsigned I = 1; I < XVec->getNumElements(); I++) {
       Value *Elt = Builder.CreateExtractElement(Cond, I);
       Result = Builder.CreateOr(Result, Elt);
     }
-    Orig->replaceAllUsesWith(Result);
   }
-  Orig->eraseFromParent();
-  return true;
+  return Result;
 }
 
-static bool expandLengthIntrinsic(CallInst *Orig) {
+static Value *expandLengthIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -182,15 +174,11 @@ static bool expandLengthIntrinsic(CallInst *Orig) {
     Value *Mul = Builder.CreateFMul(Elt, Elt);
     Sum = Builder.CreateFAdd(Sum, Mul);
   }
-  Value *Result = Builder.CreateIntrinsic(
-      EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum}, nullptr, "elt.sqrt");
-
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateIntrinsic(EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum},
+                                 nullptr, "elt.sqrt");
 }
 
-static bool expandLerpIntrinsic(CallInst *Orig) {
+static Value *expandLerpIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
   Value *S = Orig->getOperand(2);
@@ -198,14 +186,11 @@ static bool expandLerpIntrinsic(CallInst *Orig) {
   Builder.SetInsertPoint(Orig);
   auto *V = Builder.CreateFSub(Y, X);
   V = Builder.CreateFMul(S, V);
-  auto *Result = Builder.CreateFAdd(X, V, "dx.lerp");
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateFAdd(X, V, "dx.lerp");
 }
 
-static bool expandLogIntrinsic(CallInst *Orig,
-                               float LogConstVal = numbers::ln2f) {
+static Value *expandLogIntrinsic(CallInst *Orig,
+                                 float LogConstVal = numbers::ln2f) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig->getParent());
   Builder.SetInsertPoint(Orig);
@@ -221,16 +206,13 @@ static bool expandLogIntrinsic(CallInst *Orig,
       Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
   Log2Call->setTailCall(Orig->isTailCall());
   Log2Call->setAttributes(Orig->getAttributes());
-  auto *Result = Builder.CreateFMul(Ln2Const, Log2Call);
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateFMul(Ln2Const, Log2Call);
 }
-static bool expandLog10Intrinsic(CallInst *Orig) {
+static Value *expandLog10Intrinsic(CallInst *Orig) {
   return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
 }
 
-static bool expandNormalizeIntrinsic(CallInst *Orig) {
+static Value *expandNormalizeIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Type *Ty = Orig->getType();
   Type *EltTy = Ty->getScalarType();
@@ -245,11 +227,7 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
         report_fatal_error(Twine("Invalid input scalar: length is zero"),
                            /* gen_crash_diag=*/false);
     }
-    Value *Result = Builder.CreateFDiv(X, X);
-
-    Orig->replaceAllUsesWith(Result);
-    Orig->eraseFromParent();
-    return true;
+    return Builder.CreateFDiv(X, X);
   }
 
   unsigned XVecSize = XVec->getNumElements();
@@ -291,14 +269,10 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
                                                 nullptr, "dx.rsqrt");
 
   Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
-  Value *Result = Builder.CreateFMul(X, MultiplicandVec);
-
-  Orig->replaceAllUsesWith(Result);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateFMul(X, MultiplicandVec);
 }
 
-static bool expandPowIntrinsic(CallInst *Orig) {
+static Value *expandPowIntrinsic(CallInst *Orig) {
 
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
@@ -313,9 +287,7 @@ static bool expandPowIntrinsic(CallInst *Orig) {
       Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
   Exp2Call->setTailCall(Orig->isTailCall());
   Exp2Call->setAttributes(Orig->getAttributes());
-  Orig->replaceAllUsesWith(Exp2Call);
-  Orig->eraseFromParent();
-  return true;
+  return Exp2Call;
 }
 
 static Intrinsic::ID getMaxForClamp(Type *ElemTy,
@@ -344,7 +316,8 @@ static Intrinsic::ID getMinForClamp(Type *ElemTy,
   return Intrinsic::minnum;
 }
 
-static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
+static Value *expandClampIntrinsic(CallInst *Orig,
+                                   Intrinsic::ID ClampIntrinsic) {
   Value *X = Orig->getOperand(0);
   Value *Min = Orig->getOperand(1);
   Value *Max = Orig->getOperand(2);
@@ -353,43 +326,55 @@ static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
   Builder.SetInsertPoint(Orig);
   auto *MaxCall = Builder.CreateIntrinsic(
       Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
-  auto *MinCall =
-      Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
-                              {MaxCall, Max}, nullptr, "dx.min");
-
-  Orig->replaceAllUsesWith(MinCall);
-  Orig->eraseFromParent();
-  return true;
+  return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
+                                 {MaxCall, Max}, nullptr, "dx.min");
 }
 
 static bool expandIntrinsic(Function &F, CallInst *Orig) {
+  Value *Result = nullptr;
   switch (F.getIntrinsicID()) {
   case Intrinsic::abs:
-    return expandAbs(Orig);
+    Result = expandAbs(Orig);
+    break;
   case Intrinsic::exp:
-    return expandExpIntrinsic(Orig);
+    Result = expandExpIntrinsic(Orig);
+    break;
   case Intrinsic::log:
-    return expandLogIntrinsic(Orig);
+    Result = expandLogIntrinsic(Orig);
+    break;
   case Intrinsic::log10:
-    return expandLog10Intrinsic(Orig);
+    Result = expandLog10Intrinsic(Orig);
+    break;
   case Intrinsic::pow:
-    return expandPowIntrinsic(Orig);
+    Result = expandPowIntrinsic(Orig);
+    break;
   case Intrinsic::dx_any:
-    return expandAnyIntrinsic(Orig);
+    Result = expandAnyIntrinsic(Orig);
+    break;
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_clamp:
-    return expandClampIntrinsic(Orig, F.getIntrinsicID());
+    Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
+    break;
   case Intrinsic::dx_lerp:
-    return expandLerpIntrinsic(Orig);
+    Result = expandLerpIntrinsic(Orig);
+    break;
   case Intrinsic::dx_length:
-    return expandLengthIntrinsic(Orig);
+    Result = expandLengthIntrinsic(Orig);
+    break;
   case Intrinsic::dx_normalize:
-    return expandNormalizeIntrinsic(Orig);
+    Result = expandNormalizeIntrinsic(Orig);
+    break;
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
-    return expandIntegerDot(Orig, F.getIntrinsicID());
+    Result = expandIntegerDot(Orig, F.getIntrinsicID());
+    break;
   }
-  return false;
+
+  if (Result) {
+    Orig->replaceAllUsesWith(Result);
+    Orig->eraseFromParent();
+  }
+  return !!Result;
 }
 
 static bool expansionIntrinsics(Module &M) {



More information about the llvm-commits mailing list