[llvm] ff5816a - [DirectX] Add `all` lowering (#105787)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 26 10:40:15 PDT 2024


Author: Farzon Lotfi
Date: 2024-08-26T13:40:11-04:00
New Revision: ff5816ad29eba3762e1c5c576c1adf586c35dd91

URL: https://github.com/llvm/llvm-project/commit/ff5816ad29eba3762e1c5c576c1adf586c35dd91
DIFF: https://github.com/llvm/llvm-project/commit/ff5816ad29eba3762e1c5c576c1adf586c35dd91.diff

LOG: [DirectX] Add `all` lowering (#105787)

- DXILIntrinsicExpansion.cpp: Modify `any` codegen expansion to work for
`all`
- DirectX/all.ll: Add test case

completes #88946

Added: 
    llvm/test/CodeGen/DirectX/all.ll

Modified: 
    llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e49169cff8aa86..2daa4f825c3b25 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -38,6 +38,7 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::log:
   case Intrinsic::log10:
   case Intrinsic::pow:
+  case Intrinsic::dx_all:
   case Intrinsic::dx_any:
   case Intrinsic::dx_clamp:
   case Intrinsic::dx_uclamp:
@@ -54,8 +55,7 @@ static bool isIntrinsicExpansion(Function &F) {
 
 static Value *expandAbs(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
   Constant *Zero = Ty->isVectorTy()
@@ -148,8 +148,7 @@ static Value *expandIntegerDotIntrinsic(CallInst *Orig,
 
 static Value *expandExpIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
   Constant *Log2eConst =
@@ -166,13 +165,21 @@ static Value *expandExpIntrinsic(CallInst *Orig) {
   return Exp2Call;
 }
 
-static Value *expandAnyIntrinsic(CallInst *Orig) {
+static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
+                                      Intrinsic::ID intrinsicId) {
   Value *X = Orig->getOperand(0);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
 
+  auto ApplyOp = [&Builder](Intrinsic::ID IntrinsicId, Value *Result,
+                            Value *Elt) {
+    if (IntrinsicId == Intrinsic::dx_any)
+      return Builder.CreateOr(Result, Elt);
+    assert(IntrinsicId == Intrinsic::dx_all);
+    return Builder.CreateAnd(Result, Elt);
+  };
+
   Value *Result = nullptr;
   if (!Ty->isVectorTy()) {
     Result = EltTy->isFloatingPointTy()
@@ -193,7 +200,7 @@ static Value *expandAnyIntrinsic(CallInst *Orig) {
     Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
     for (unsigned I = 1; I < XVec->getNumElements(); I++) {
       Value *Elt = Builder.CreateExtractElement(Cond, I);
-      Result = Builder.CreateOr(Result, Elt);
+      Result = ApplyOp(intrinsicId, Result, Elt);
     }
   }
   return Result;
@@ -201,8 +208,7 @@ static Value *expandAnyIntrinsic(CallInst *Orig) {
 
 static Value *expandLengthIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
 
@@ -230,8 +236,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
   Value *S = Orig->getOperand(2);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   auto *V = Builder.CreateFSub(Y, X);
   V = Builder.CreateFMul(S, V);
   return Builder.CreateFAdd(X, V, "dx.lerp");
@@ -240,8 +245,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
 static Value *expandLogIntrinsic(CallInst *Orig,
                                  float LogConstVal = numbers::ln2f) {
   Value *X = Orig->getOperand(0);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
   Constant *Ln2Const =
@@ -266,8 +270,7 @@ static Value *expandNormalizeIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Type *Ty = Orig->getType();
   Type *EltTy = Ty->getScalarType();
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
 
   auto *XVec = dyn_cast<FixedVectorType>(Ty);
   if (!XVec) {
@@ -305,8 +308,7 @@ static Value *expandPowIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
   Type *Ty = X->getType();
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
 
   auto *Log2Call =
       Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
@@ -350,8 +352,7 @@ static Value *expandClampIntrinsic(CallInst *Orig,
   Value *Min = Orig->getOperand(1);
   Value *Max = Orig->getOperand(2);
   Type *Ty = X->getType();
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   auto *MaxCall = Builder.CreateIntrinsic(
       Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
   return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
@@ -360,7 +361,8 @@ static Value *expandClampIntrinsic(CallInst *Orig,
 
 static bool expandIntrinsic(Function &F, CallInst *Orig) {
   Value *Result = nullptr;
-  switch (F.getIntrinsicID()) {
+  Intrinsic::ID IntrinsicId = F.getIntrinsicID();
+  switch (IntrinsicId) {
   case Intrinsic::abs:
     Result = expandAbs(Orig);
     break;
@@ -376,12 +378,13 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
   case Intrinsic::pow:
     Result = expandPowIntrinsic(Orig);
     break;
+  case Intrinsic::dx_all:
   case Intrinsic::dx_any:
-    Result = expandAnyIntrinsic(Orig);
+    Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
     break;
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_clamp:
-    Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
+    Result = expandClampIntrinsic(Orig, IntrinsicId);
     break;
   case Intrinsic::dx_lerp:
     Result = expandLerpIntrinsic(Orig);
@@ -397,7 +400,7 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
     break;
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
-    Result = expandIntegerDotIntrinsic(Orig, F.getIntrinsicID());
+    Result = expandIntegerDotIntrinsic(Orig, IntrinsicId);
     break;
   }
 

diff  --git a/llvm/test/CodeGen/DirectX/all.ll b/llvm/test/CodeGen/DirectX/all.ll
new file mode 100644
index 00000000000000..1c0b6486dc9358
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/all.ll
@@ -0,0 +1,83 @@
+; RUN: opt -S -passes=dxil-intrinsic-expansion,dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library < %s | FileCheck %s
+
+; Make sure dx.all is expanded to a compare against zero for bool, integer, and float types (and-reduced for vectors).
+
+; CHECK-LABEL: @all_bool(
+; CHECK: icmp ne i1 %{{.*}}, false
+define noundef i1 @all_bool(i1 noundef %p0) {
+entry:
+  %dx.all = call i1 @llvm.dx.all.i1(i1 %p0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: @all_int64_t(
+; CHECK: icmp ne i64 %{{.*}}, 0
+define noundef i1 @all_int64_t(i64 noundef %p0) {
+entry:
+  %dx.all = call i1 @llvm.dx.all.i64(i64 %p0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: @all_int(
+; CHECK: icmp ne i32 %{{.*}}, 0
+define noundef i1 @all_int(i32 noundef %p0) {
+entry:
+  %dx.all = call i1 @llvm.dx.all.i32(i32 %p0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: @all_int16_t(
+; CHECK: icmp ne i16 %{{.*}}, 0
+define noundef i1 @all_int16_t(i16 noundef %p0) {
+entry:
+  %dx.all = call i1 @llvm.dx.all.i16(i16 %p0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: @all_double(
+; CHECK: fcmp une double %{{.*}}, 0.000000e+00
+define noundef i1 @all_double(double noundef %p0) {
+entry:
+  %dx.all = call i1 @llvm.dx.all.f64(double %p0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: @all_float(
+; CHECK: fcmp une float %{{.*}}, 0.000000e+00
+define noundef i1 @all_float(float noundef %p0) {
+entry:
+  %dx.all = call i1 @llvm.dx.all.f32(float %p0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: @all_half(
+; CHECK: fcmp une half %{{.*}}, 0xH0000
+define noundef i1 @all_half(half noundef %p0) {
+entry:
+  %dx.all = call i1 @llvm.dx.all.f16(half %p0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: @all_bool4(
+; CHECK: icmp ne <4 x i1> %{{.*}}, zeroinitializer
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 0
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 1
+; CHECK: and i1 %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 2
+; CHECK: and i1 %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 3
+; CHECK: and i1 %{{.*}}, %{{.*}}
+define noundef i1 @all_bool4(<4 x i1> noundef %p0) {
+entry:
+  %dx.all = call i1 @llvm.dx.all.v4i1(<4 x i1> %p0)
+  ret i1 %dx.all
+}
+
+declare i1 @llvm.dx.all.v4i1(<4 x i1>)
+declare i1 @llvm.dx.all.i1(i1)
+declare i1 @llvm.dx.all.i16(i16)
+declare i1 @llvm.dx.all.i32(i32)
+declare i1 @llvm.dx.all.i64(i64)
+declare i1 @llvm.dx.all.f16(half)
+declare i1 @llvm.dx.all.f32(float)
+declare i1 @llvm.dx.all.f64(double)


        


More information about the llvm-commits mailing list