[llvm] ff5816a - [DirectX] Add `all` lowering (#105787)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 26 10:40:15 PDT 2024
Author: Farzon Lotfi
Date: 2024-08-26T13:40:11-04:00
New Revision: ff5816ad29eba3762e1c5c576c1adf586c35dd91
URL: https://github.com/llvm/llvm-project/commit/ff5816ad29eba3762e1c5c576c1adf586c35dd91
DIFF: https://github.com/llvm/llvm-project/commit/ff5816ad29eba3762e1c5c576c1adf586c35dd91.diff
LOG: [DirectX] Add `all` lowering (#105787)
- DXILIntrinsicExpansion.cpp: Modify `any` codegen expansion to work for
`all`
- DirectX\all.ll: Add test case
completes #88946
Added:
llvm/test/CodeGen/DirectX/all.ll
Modified:
llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e49169cff8aa86..2daa4f825c3b25 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -38,6 +38,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::log:
case Intrinsic::log10:
case Intrinsic::pow:
+ case Intrinsic::dx_all:
case Intrinsic::dx_any:
case Intrinsic::dx_clamp:
case Intrinsic::dx_uclamp:
@@ -54,8 +55,7 @@ static bool isIntrinsicExpansion(Function &F) {
static Value *expandAbs(CallInst *Orig) {
Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
Constant *Zero = Ty->isVectorTy()
@@ -148,8 +148,7 @@ static Value *expandIntegerDotIntrinsic(CallInst *Orig,
static Value *expandExpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
Constant *Log2eConst =
@@ -166,13 +165,21 @@ static Value *expandExpIntrinsic(CallInst *Orig) {
return Exp2Call;
}
-static Value *expandAnyIntrinsic(CallInst *Orig) {
+static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
+ Intrinsic::ID intrinsicId) {
Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
+ auto ApplyOp = [&Builder](Intrinsic::ID IntrinsicId, Value *Result,
+ Value *Elt) {
+ if (IntrinsicId == Intrinsic::dx_any)
+ return Builder.CreateOr(Result, Elt);
+ assert(IntrinsicId == Intrinsic::dx_all);
+ return Builder.CreateAnd(Result, Elt);
+ };
+
Value *Result = nullptr;
if (!Ty->isVectorTy()) {
Result = EltTy->isFloatingPointTy()
@@ -193,7 +200,7 @@ static Value *expandAnyIntrinsic(CallInst *Orig) {
Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
for (unsigned I = 1; I < XVec->getNumElements(); I++) {
Value *Elt = Builder.CreateExtractElement(Cond, I);
- Result = Builder.CreateOr(Result, Elt);
+ Result = ApplyOp(intrinsicId, Result, Elt);
}
}
return Result;
@@ -201,8 +208,7 @@ static Value *expandAnyIntrinsic(CallInst *Orig) {
static Value *expandLengthIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
@@ -230,8 +236,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Value *Y = Orig->getOperand(1);
Value *S = Orig->getOperand(2);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
auto *V = Builder.CreateFSub(Y, X);
V = Builder.CreateFMul(S, V);
return Builder.CreateFAdd(X, V, "dx.lerp");
@@ -240,8 +245,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
static Value *expandLogIntrinsic(CallInst *Orig,
float LogConstVal = numbers::ln2f) {
Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
Constant *Ln2Const =
@@ -266,8 +270,7 @@ static Value *expandNormalizeIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Type *Ty = Orig->getType();
Type *EltTy = Ty->getScalarType();
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
auto *XVec = dyn_cast<FixedVectorType>(Ty);
if (!XVec) {
@@ -305,8 +308,7 @@ static Value *expandPowIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Value *Y = Orig->getOperand(1);
Type *Ty = X->getType();
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
auto *Log2Call =
Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
@@ -350,8 +352,7 @@ static Value *expandClampIntrinsic(CallInst *Orig,
Value *Min = Orig->getOperand(1);
Value *Max = Orig->getOperand(2);
Type *Ty = X->getType();
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
auto *MaxCall = Builder.CreateIntrinsic(
Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
@@ -360,7 +361,8 @@ static Value *expandClampIntrinsic(CallInst *Orig,
static bool expandIntrinsic(Function &F, CallInst *Orig) {
Value *Result = nullptr;
- switch (F.getIntrinsicID()) {
+ Intrinsic::ID IntrinsicId = F.getIntrinsicID();
+ switch (IntrinsicId) {
case Intrinsic::abs:
Result = expandAbs(Orig);
break;
@@ -376,12 +378,13 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
case Intrinsic::pow:
Result = expandPowIntrinsic(Orig);
break;
+ case Intrinsic::dx_all:
case Intrinsic::dx_any:
- Result = expandAnyIntrinsic(Orig);
+ Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
break;
case Intrinsic::dx_uclamp:
case Intrinsic::dx_clamp:
- Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
+ Result = expandClampIntrinsic(Orig, IntrinsicId);
break;
case Intrinsic::dx_lerp:
Result = expandLerpIntrinsic(Orig);
@@ -397,7 +400,7 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
break;
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
- Result = expandIntegerDotIntrinsic(Orig, F.getIntrinsicID());
+ Result = expandIntegerDotIntrinsic(Orig, IntrinsicId);
break;
}
diff --git a/llvm/test/CodeGen/DirectX/all.ll b/llvm/test/CodeGen/DirectX/all.ll
new file mode 100644
index 00000000000000..1c0b6486dc9358
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/all.ll
@@ -0,0 +1,83 @@
+; RUN: opt -S -passes=dxil-intrinsic-expansion,dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library < %s | FileCheck %s
+
+; Make sure the dx.all intrinsic is expanded to compare/extractelement/and sequences for scalar bool, integer, and floating-point types and for a bool vector.
+
+; CHECK-LABEL: all_bool
+; CHECK: icmp ne i1 %{{.*}}, false
+define noundef i1 @all_bool(i1 noundef %p0) {
+entry:
+ %dx.all = call i1 @llvm.dx.all.i1(i1 %p0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_int64_t
+; CHECK: icmp ne i64 %{{.*}}, 0
+define noundef i1 @all_int64_t(i64 noundef %p0) {
+entry:
+ %dx.all = call i1 @llvm.dx.all.i64(i64 %p0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_int
+; CHECK: icmp ne i32 %{{.*}}, 0
+define noundef i1 @all_int(i32 noundef %p0) {
+entry:
+ %dx.all = call i1 @llvm.dx.all.i32(i32 %p0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_int16_t
+; CHECK: icmp ne i16 %{{.*}}, 0
+define noundef i1 @all_int16_t(i16 noundef %p0) {
+entry:
+ %dx.all = call i1 @llvm.dx.all.i16(i16 %p0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_double
+; CHECK: fcmp une double %{{.*}}, 0.000000e+00
+define noundef i1 @all_double(double noundef %p0) {
+entry:
+ %dx.all = call i1 @llvm.dx.all.f64(double %p0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_float
+; CHECK: fcmp une float %{{.*}}, 0.000000e+00
+define noundef i1 @all_float(float noundef %p0) {
+entry:
+ %dx.all = call i1 @llvm.dx.all.f32(float %p0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_half
+; CHECK: fcmp une half %{{.*}}, 0xH0000
+define noundef i1 @all_half(half noundef %p0) {
+entry:
+ %dx.all = call i1 @llvm.dx.all.f16(half %p0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_bool4
+; CHECK: icmp ne <4 x i1> %{{.*}}, zeroinitializer
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 0
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 1
+; CHECK: and i1 %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 2
+; CHECK: and i1 %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 3
+; CHECK: and i1 %{{.*}}, %{{.*}}
+define noundef i1 @all_bool4(<4 x i1> noundef %p0) {
+entry:
+ %dx.all = call i1 @llvm.dx.all.v4i1(<4 x i1> %p0)
+ ret i1 %dx.all
+}
+
+declare i1 @llvm.dx.all.v4i1(<4 x i1>)
+declare i1 @llvm.dx.all.i1(i1)
+declare i1 @llvm.dx.all.i16(i16)
+declare i1 @llvm.dx.all.i32(i32)
+declare i1 @llvm.dx.all.i64(i64)
+declare i1 @llvm.dx.all.f16(half)
+declare i1 @llvm.dx.all.f32(float)
+declare i1 @llvm.dx.all.f64(double)
More information about the llvm-commits mailing list