[llvm] [DirectX] Add `all` lowering (PR #105787)

Farzon Lotfi via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 22 23:12:54 PDT 2024


https://github.com/farzonl created https://github.com/llvm/llvm-project/pull/105787

- DXILIntrinsicExpansion.cpp: Generalize the `any` expansion so it also lowers `all` (OR-reduction of the per-element compares for `any`, AND-reduction for `all`)
- DirectX/all.ll: Add test case

Completes #88946.

>From 91879fe1ba93bf461d73843789dcaf0aad86946e Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Fri, 23 Aug 2024 00:01:33 -0400
Subject: [PATCH] [DirectX] Add `all` lowering

- DXILIntrinsicExpansion.cpp: Modify `any` codegen expansion to work for `all`
- DirectX/all.ll: Add test case

---
 .../Target/DirectX/DXILIntrinsicExpansion.cpp |  51 ++++----
 llvm/test/CodeGen/DirectX/all.ll              | 113 ++++++++++++++++++
 2 files changed, 140 insertions(+), 24 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/all.ll

diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e49169cff8aa86..2daa4f825c3b25 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -38,6 +38,7 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::log:
   case Intrinsic::log10:
   case Intrinsic::pow:
+  case Intrinsic::dx_all:
   case Intrinsic::dx_any:
   case Intrinsic::dx_clamp:
   case Intrinsic::dx_uclamp:
@@ -54,8 +55,7 @@ static bool isIntrinsicExpansion(Function &F) {
 
 static Value *expandAbs(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
   Constant *Zero = Ty->isVectorTy()
@@ -148,8 +148,7 @@ static Value *expandIntegerDotIntrinsic(CallInst *Orig,
 
 static Value *expandExpIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
   Constant *Log2eConst =
@@ -166,13 +165,21 @@ static Value *expandExpIntrinsic(CallInst *Orig) {
   return Exp2Call;
 }
 
-static Value *expandAnyIntrinsic(CallInst *Orig) {
+static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
+                                      Intrinsic::ID intrinsicId) {
   Value *X = Orig->getOperand(0);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
 
+  auto ApplyOp = [&Builder](Intrinsic::ID IntrinsicId, Value *Result,
+                            Value *Elt) {
+    if (IntrinsicId == Intrinsic::dx_any)
+      return Builder.CreateOr(Result, Elt);
+    assert(IntrinsicId == Intrinsic::dx_all);
+    return Builder.CreateAnd(Result, Elt);
+  };
+
   Value *Result = nullptr;
   if (!Ty->isVectorTy()) {
     Result = EltTy->isFloatingPointTy()
@@ -193,7 +200,7 @@ static Value *expandAnyIntrinsic(CallInst *Orig) {
     Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
     for (unsigned I = 1; I < XVec->getNumElements(); I++) {
       Value *Elt = Builder.CreateExtractElement(Cond, I);
-      Result = Builder.CreateOr(Result, Elt);
+      Result = ApplyOp(intrinsicId, Result, Elt);
     }
   }
   return Result;
@@ -201,8 +208,7 @@ static Value *expandAnyIntrinsic(CallInst *Orig) {
 
 static Value *expandLengthIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
 
@@ -230,8 +236,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
   Value *S = Orig->getOperand(2);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   auto *V = Builder.CreateFSub(Y, X);
   V = Builder.CreateFMul(S, V);
   return Builder.CreateFAdd(X, V, "dx.lerp");
@@ -240,8 +245,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
 static Value *expandLogIntrinsic(CallInst *Orig,
                                  float LogConstVal = numbers::ln2f) {
   Value *X = Orig->getOperand(0);
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
   Type *EltTy = Ty->getScalarType();
   Constant *Ln2Const =
@@ -266,8 +270,7 @@ static Value *expandNormalizeIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Type *Ty = Orig->getType();
   Type *EltTy = Ty->getScalarType();
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
 
   auto *XVec = dyn_cast<FixedVectorType>(Ty);
   if (!XVec) {
@@ -305,8 +308,7 @@ static Value *expandPowIntrinsic(CallInst *Orig) {
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
   Type *Ty = X->getType();
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
 
   auto *Log2Call =
       Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
@@ -350,8 +352,7 @@ static Value *expandClampIntrinsic(CallInst *Orig,
   Value *Min = Orig->getOperand(1);
   Value *Max = Orig->getOperand(2);
   Type *Ty = X->getType();
-  IRBuilder<> Builder(Orig->getParent());
-  Builder.SetInsertPoint(Orig);
+  IRBuilder<> Builder(Orig);
   auto *MaxCall = Builder.CreateIntrinsic(
       Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
   return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
@@ -360,7 +361,8 @@ static Value *expandClampIntrinsic(CallInst *Orig,
 
 static bool expandIntrinsic(Function &F, CallInst *Orig) {
   Value *Result = nullptr;
-  switch (F.getIntrinsicID()) {
+  Intrinsic::ID IntrinsicId = F.getIntrinsicID();
+  switch (IntrinsicId) {
   case Intrinsic::abs:
     Result = expandAbs(Orig);
     break;
@@ -376,12 +378,13 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
   case Intrinsic::pow:
     Result = expandPowIntrinsic(Orig);
     break;
+  case Intrinsic::dx_all:
   case Intrinsic::dx_any:
-    Result = expandAnyIntrinsic(Orig);
+    Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
     break;
   case Intrinsic::dx_uclamp:
   case Intrinsic::dx_clamp:
-    Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
+    Result = expandClampIntrinsic(Orig, IntrinsicId);
     break;
   case Intrinsic::dx_lerp:
     Result = expandLerpIntrinsic(Orig);
@@ -397,7 +400,7 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
     break;
   case Intrinsic::dx_sdot:
   case Intrinsic::dx_udot:
-    Result = expandIntegerDotIntrinsic(Orig, F.getIntrinsicID());
+    Result = expandIntegerDotIntrinsic(Orig, IntrinsicId);
     break;
   }
 
diff --git a/llvm/test/CodeGen/DirectX/all.ll b/llvm/test/CodeGen/DirectX/all.ll
new file mode 100644
index 00000000000000..c82d14f05ee640
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/all.ll
@@ -0,0 +1,113 @@
+; RUN: opt -S -passes=dxil-intrinsic-expansion,dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library < %s | FileCheck %s
+
+; Make sure dx.all is expanded to compares against zero (and-reduced for vectors) for bool, int, and float types.
+
+; CHECK-LABEL: @all_bool(
+; CHECK: icmp ne i1 %{{.*}}, false
+define noundef i1 @all_bool(i1 noundef %p0) {
+entry:
+  %p0.addr = alloca i8, align 1
+  %frombool = zext i1 %p0 to i8
+  store i8 %frombool, ptr %p0.addr, align 1
+  %0 = load i8, ptr %p0.addr, align 1
+  %tobool = trunc i8 %0 to i1
+  %dx.all = call i1 @llvm.dx.all.i1(i1 %tobool)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: @all_int64_t(
+; CHECK: icmp ne i64 %{{.*}}, 0
+define noundef i1 @all_int64_t(i64 noundef %p0) {
+entry:
+  %p0.addr = alloca i64, align 8
+  store i64 %p0, ptr %p0.addr, align 8
+  %0 = load i64, ptr %p0.addr, align 8
+  %dx.all = call i1 @llvm.dx.all.i64(i64 %0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: @all_int(
+; CHECK: icmp ne i32 %{{.*}}, 0
+define noundef i1 @all_int(i32 noundef %p0) {
+entry:
+  %p0.addr = alloca i32, align 4
+  store i32 %p0, ptr %p0.addr, align 4
+  %0 = load i32, ptr %p0.addr, align 4
+  %dx.all = call i1 @llvm.dx.all.i32(i32 %0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: @all_int16_t(
+; CHECK: icmp ne i16 %{{.*}}, 0
+define noundef i1 @all_int16_t(i16 noundef %p0) {
+entry:
+  %p0.addr = alloca i16, align 2
+  store i16 %p0, ptr %p0.addr, align 2
+  %0 = load i16, ptr %p0.addr, align 2
+  %dx.all = call i1 @llvm.dx.all.i16(i16 %0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: @all_double(
+; CHECK: fcmp une double %{{.*}}, 0.000000e+00
+define noundef i1 @all_double(double noundef %p0) {
+entry:
+  %p0.addr = alloca double, align 8
+  store double %p0, ptr %p0.addr, align 8
+  %0 = load double, ptr %p0.addr, align 8
+  %dx.all = call i1 @llvm.dx.all.f64(double %0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: @all_float(
+; CHECK: fcmp une float %{{.*}}, 0.000000e+00
+define noundef i1 @all_float(float noundef %p0) {
+entry:
+  %p0.addr = alloca float, align 4
+  store float %p0, ptr %p0.addr, align 4
+  %0 = load float, ptr %p0.addr, align 4
+  %dx.all = call i1 @llvm.dx.all.f32(float %0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: @all_half(
+; CHECK: fcmp une half %{{.*}}, 0xH0000
+define noundef i1 @all_half(half noundef %p0) {
+entry:
+  %p0.addr = alloca half, align 2
+  store half %p0, ptr %p0.addr, align 2
+  %0 = load half, ptr %p0.addr, align 2
+  %dx.all = call i1 @llvm.dx.all.f16(half %0)
+  ret i1 %dx.all
+}
+
+; CHECK-LABEL: @all_bool4(
+; CHECK: [[CMP:%.*]] = icmp ne <4 x i1> %{{.*}}, zeroinitializer
+; CHECK: [[E0:%.*]] = extractelement <4 x i1> [[CMP]], i64 0
+; CHECK: [[E1:%.*]] = extractelement <4 x i1> [[CMP]], i64 1
+; CHECK: [[A0:%.*]] = and i1 [[E0]], [[E1]]
+; CHECK: [[E2:%.*]] = extractelement <4 x i1> [[CMP]], i64 2
+; CHECK: [[A1:%.*]] = and i1 [[A0]], [[E2]]
+; CHECK: [[E3:%.*]] = extractelement <4 x i1> [[CMP]], i64 3
+; CHECK: and i1 [[A1]], [[E3]]
+define noundef i1 @all_bool4(<4 x i1> noundef %p0) {
+entry:
+  %p0.addr = alloca i8, align 1
+  %insertvec = shufflevector <4 x i1> %p0, <4 x i1> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison>
+  %0 = bitcast <8 x i1> %insertvec to i8
+  store i8 %0, ptr %p0.addr, align 1
+  %load_bits = load i8, ptr %p0.addr, align 1
+  %1 = bitcast i8 %load_bits to <8 x i1>
+  %extractvec = shufflevector <8 x i1> %1, <8 x i1> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %dx.all = call i1 @llvm.dx.all.v4i1(<4 x i1> %extractvec)
+  ret i1 %dx.all
+}
+
+declare i1 @llvm.dx.all.v4i1(<4 x i1>)
+declare i1 @llvm.dx.all.i1(i1)
+declare i1 @llvm.dx.all.i16(i16)
+declare i1 @llvm.dx.all.i32(i32)
+declare i1 @llvm.dx.all.i64(i64)
+declare i1 @llvm.dx.all.f16(half)
+declare i1 @llvm.dx.all.f32(float)
+declare i1 @llvm.dx.all.f64(double)



More information about the llvm-commits mailing list