[llvm] [DirectX] legalize powi (PR #135228)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 10 11:18:51 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: Farzon Lotfi (farzonl)

<details>
<summary>Changes</summary>

Fixes #<!-- -->135221
- Have `powi` use the same legalization path as `pow`.
- Use `CreateSIToFP` to convert the integer exponent to the floating-point type of the base.
- Add tests for `powi`.

---
Full diff: https://github.com/llvm/llvm-project/pull/135228.diff


2 Files Affected:

- (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+9-4) 
- (modified) llvm/test/CodeGen/DirectX/pow.ll (+26) 


``````````diff
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e44d3b70eb657..84acf4d536d0c 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -49,6 +49,7 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::log:
   case Intrinsic::log10:
   case Intrinsic::pow:
+  case Intrinsic::powi:
   case Intrinsic::dx_all:
   case Intrinsic::dx_any:
   case Intrinsic::dx_cross:
@@ -251,7 +252,7 @@ static Value *expandExpIntrinsic(CallInst *Orig) {
 }
 
 static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
-                                      Intrinsic::ID intrinsicId) {
+                                      Intrinsic::ID IntrinsicId) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
@@ -285,7 +286,7 @@ static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
     Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
     for (unsigned I = 1; I < XVec->getNumElements(); I++) {
       Value *Elt = Builder.CreateExtractElement(Cond, I);
-      Result = ApplyOp(intrinsicId, Result, Elt);
+      Result = ApplyOp(IntrinsicId, Result, Elt);
     }
   }
   return Result;
@@ -410,13 +411,16 @@ static Value *expandAtan2Intrinsic(CallInst *Orig) {
   return Result;
 }
 
-static Value *expandPowIntrinsic(CallInst *Orig) {
+static Value *expandPowIntrinsic(CallInst *Orig, Intrinsic::ID IntrinsicId) {
 
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
   Type *Ty = X->getType();
   IRBuilder<> Builder(Orig);
 
+  if (IntrinsicId == Intrinsic::powi)
+    Y = Builder.CreateSIToFP(Y, Ty);
+
   auto *Log2Call =
       Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
   auto *Mul = Builder.CreateFMul(Log2Call, Y);
@@ -542,7 +546,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
     Result = expandLog10Intrinsic(Orig);
     break;
   case Intrinsic::pow:
-    Result = expandPowIntrinsic(Orig);
+  case Intrinsic::powi:
+    Result = expandPowIntrinsic(Orig, IntrinsicId);
     break;
   case Intrinsic::dx_all:
   case Intrinsic::dx_any:
diff --git a/llvm/test/CodeGen/DirectX/pow.ll b/llvm/test/CodeGen/DirectX/pow.ll
index 378649ec119d5..ebc0382415801 100644
--- a/llvm/test/CodeGen/DirectX/pow.ll
+++ b/llvm/test/CodeGen/DirectX/pow.ll
@@ -25,5 +25,31 @@ entry:
   ret half %elt.pow
 }
 
+define noundef float @powi_float(float noundef %a, i32 noundef %b) {
+entry:
+; CHECK: [[CAST:%.*]] = sitofp i32 %b to float
+; DOPCHECK: call float @dx.op.unary.f32(i32 23, float %a)
+; EXPCHECK: call float @llvm.log2.f32(float %a)
+; CHECK: fmul float %{{.*}}, [[CAST]]
+; DOPCHECK: call float @dx.op.unary.f32(i32 21, float %{{.*}})
+; EXPCHECK: call float @llvm.exp2.f32(float %{{.*}})
+  %elt.powi = call float @llvm.powi.f32.i32(float %a, i32 %b)
+  ret float %elt.powi
+}
+
+define noundef half @powi_half(half noundef %a, i32 noundef %b) {
+entry:
+; CHECK: [[CAST:%.*]] = sitofp i32 %b to half
+; DOPCHECK: call half @dx.op.unary.f16(i32 23, half %a)
+; EXPCHECK: call half @llvm.log2.f16(half %a)
+; CHECK: fmul half %{{.*}}, [[CAST]]
+; DOPCHECK: call half @dx.op.unary.f16(i32 21, half %{{.*}})
+; EXPCHECK: call half @llvm.exp2.f16(half %{{.*}})
+  %elt.powi = call half @llvm.powi.f16.i32(half %a, i32 %b)
+  ret half %elt.powi
+}
+
 declare half @llvm.pow.f16(half,half)
 declare float @llvm.pow.f32(float,float)
+declare half @llvm.powi.f16.i32(half,i32)
+declare float @llvm.powi.f32.i32(float,i32)

``````````

</details>


https://github.com/llvm/llvm-project/pull/135228


More information about the llvm-commits mailing list