[llvm] [DirectX] Add `all` lowering (PR #105787)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 22 23:12:54 PDT 2024
https://github.com/farzonl created https://github.com/llvm/llvm-project/pull/105787
- DXILIntrinsicExpansion.cpp: Modify `any` codegen expansion to work for `all`
- DirectX/all.ll: Add test case
completes #88946
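For context, this is roughly the IR the shared expansion produces for a <4 x i1> input once `all` goes through the same path as `any` (value names below are illustrative, not what the pass actually emits; the real output is pinned down by the CHECK lines in DirectX/all.ll):

  %cmp = icmp ne <4 x i1> %x, zeroinitializer
  %e0  = extractelement <4 x i1> %cmp, i64 0
  %e1  = extractelement <4 x i1> %cmp, i64 1
  %a0  = and i1 %e0, %e1
  %e2  = extractelement <4 x i1> %cmp, i64 2
  %a1  = and i1 %a0, %e2
  %e3  = extractelement <4 x i1> %cmp, i64 3
  %all = and i1 %a1, %e3

Scalar inputs reduce to a single compare against zero (e.g. `fcmp une float %x, 0.000000e+00`); the only difference from `any` is that the per-element results are combined with `and` instead of `or`.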
From 91879fe1ba93bf461d73843789dcaf0aad86946e Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzon at farzon.org>
Date: Fri, 23 Aug 2024 00:01:33 -0400
Subject: [PATCH] [DirectX] Add All lowering - DXILIntrinsicExpansion.cpp:
Modify `any` codegen expansion to work for `all` - DirectX/all.ll: Add test
case
---
.../Target/DirectX/DXILIntrinsicExpansion.cpp | 51 ++++----
llvm/test/CodeGen/DirectX/all.ll | 113 ++++++++++++++++++
2 files changed, 140 insertions(+), 24 deletions(-)
create mode 100644 llvm/test/CodeGen/DirectX/all.ll
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e49169cff8aa86..2daa4f825c3b25 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -38,6 +38,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::log:
case Intrinsic::log10:
case Intrinsic::pow:
+ case Intrinsic::dx_all:
case Intrinsic::dx_any:
case Intrinsic::dx_clamp:
case Intrinsic::dx_uclamp:
@@ -54,8 +55,7 @@ static bool isIntrinsicExpansion(Function &F) {
static Value *expandAbs(CallInst *Orig) {
Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
Constant *Zero = Ty->isVectorTy()
@@ -148,8 +148,7 @@ static Value *expandIntegerDotIntrinsic(CallInst *Orig,
static Value *expandExpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
Constant *Log2eConst =
@@ -166,13 +165,21 @@ static Value *expandExpIntrinsic(CallInst *Orig) {
return Exp2Call;
}
-static Value *expandAnyIntrinsic(CallInst *Orig) {
+static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
+ Intrinsic::ID intrinsicId) {
Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
+ auto ApplyOp = [&Builder](Intrinsic::ID IntrinsicId, Value *Result,
+ Value *Elt) {
+ if (IntrinsicId == Intrinsic::dx_any)
+ return Builder.CreateOr(Result, Elt);
+ assert(IntrinsicId == Intrinsic::dx_all);
+ return Builder.CreateAnd(Result, Elt);
+ };
+
Value *Result = nullptr;
if (!Ty->isVectorTy()) {
Result = EltTy->isFloatingPointTy()
@@ -193,7 +200,7 @@ static Value *expandAnyIntrinsic(CallInst *Orig) {
Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
for (unsigned I = 1; I < XVec->getNumElements(); I++) {
Value *Elt = Builder.CreateExtractElement(Cond, I);
- Result = Builder.CreateOr(Result, Elt);
+ Result = ApplyOp(intrinsicId, Result, Elt);
}
}
return Result;
@@ -201,8 +208,7 @@ static Value *expandAnyIntrinsic(CallInst *Orig) {
static Value *expandLengthIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
@@ -230,8 +236,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Value *Y = Orig->getOperand(1);
Value *S = Orig->getOperand(2);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
auto *V = Builder.CreateFSub(Y, X);
V = Builder.CreateFMul(S, V);
return Builder.CreateFAdd(X, V, "dx.lerp");
@@ -240,8 +245,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
static Value *expandLogIntrinsic(CallInst *Orig,
float LogConstVal = numbers::ln2f) {
Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
Constant *Ln2Const =
@@ -266,8 +270,7 @@ static Value *expandNormalizeIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Type *Ty = Orig->getType();
Type *EltTy = Ty->getScalarType();
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
auto *XVec = dyn_cast<FixedVectorType>(Ty);
if (!XVec) {
@@ -305,8 +308,7 @@ static Value *expandPowIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Value *Y = Orig->getOperand(1);
Type *Ty = X->getType();
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
auto *Log2Call =
Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
@@ -350,8 +352,7 @@ static Value *expandClampIntrinsic(CallInst *Orig,
Value *Min = Orig->getOperand(1);
Value *Max = Orig->getOperand(2);
Type *Ty = X->getType();
- IRBuilder<> Builder(Orig->getParent());
- Builder.SetInsertPoint(Orig);
+ IRBuilder<> Builder(Orig);
auto *MaxCall = Builder.CreateIntrinsic(
Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
@@ -360,7 +361,8 @@ static Value *expandClampIntrinsic(CallInst *Orig,
static bool expandIntrinsic(Function &F, CallInst *Orig) {
Value *Result = nullptr;
- switch (F.getIntrinsicID()) {
+ Intrinsic::ID IntrinsicId = F.getIntrinsicID();
+ switch (IntrinsicId) {
case Intrinsic::abs:
Result = expandAbs(Orig);
break;
@@ -376,12 +378,13 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
case Intrinsic::pow:
Result = expandPowIntrinsic(Orig);
break;
+ case Intrinsic::dx_all:
case Intrinsic::dx_any:
- Result = expandAnyIntrinsic(Orig);
+ Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
break;
case Intrinsic::dx_uclamp:
case Intrinsic::dx_clamp:
- Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
+ Result = expandClampIntrinsic(Orig, IntrinsicId);
break;
case Intrinsic::dx_lerp:
Result = expandLerpIntrinsic(Orig);
@@ -397,7 +400,7 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
break;
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
- Result = expandIntegerDotIntrinsic(Orig, F.getIntrinsicID());
+ Result = expandIntegerDotIntrinsic(Orig, IntrinsicId);
break;
}
diff --git a/llvm/test/CodeGen/DirectX/all.ll b/llvm/test/CodeGen/DirectX/all.ll
new file mode 100644
index 00000000000000..c82d14f05ee640
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/all.ll
@@ -0,0 +1,113 @@
+; RUN: opt -S -passes=dxil-intrinsic-expansion,dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library < %s | FileCheck %s
+
+; Make sure the all intrinsic expands correctly for bool, int16/32/64, half, float, and double scalars, and for a bool vector.
+
+; CHECK-LABEL: all_bool
+; CHECK: icmp ne i1 %{{.*}}, false
+define noundef i1 @all_bool(i1 noundef %p0) {
+entry:
+ %p0.addr = alloca i8, align 1
+ %frombool = zext i1 %p0 to i8
+ store i8 %frombool, ptr %p0.addr, align 1
+ %0 = load i8, ptr %p0.addr, align 1
+ %tobool = trunc i8 %0 to i1
+ %dx.all = call i1 @llvm.dx.all.i1(i1 %tobool)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_int64_t
+; CHECK: icmp ne i64 %{{.*}}, 0
+define noundef i1 @all_int64_t(i64 noundef %p0) {
+entry:
+ %p0.addr = alloca i64, align 8
+ store i64 %p0, ptr %p0.addr, align 8
+ %0 = load i64, ptr %p0.addr, align 8
+ %dx.all = call i1 @llvm.dx.all.i64(i64 %0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_int
+; CHECK: icmp ne i32 %{{.*}}, 0
+define noundef i1 @all_int(i32 noundef %p0) {
+entry:
+ %p0.addr = alloca i32, align 4
+ store i32 %p0, ptr %p0.addr, align 4
+ %0 = load i32, ptr %p0.addr, align 4
+ %dx.all = call i1 @llvm.dx.all.i32(i32 %0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_int16_t
+; CHECK: icmp ne i16 %{{.*}}, 0
+define noundef i1 @all_int16_t(i16 noundef %p0) {
+entry:
+ %p0.addr = alloca i16, align 2
+ store i16 %p0, ptr %p0.addr, align 2
+ %0 = load i16, ptr %p0.addr, align 2
+ %dx.all = call i1 @llvm.dx.all.i16(i16 %0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_double
+; CHECK: fcmp une double %{{.*}}, 0.000000e+00
+define noundef i1 @all_double(double noundef %p0) {
+entry:
+ %p0.addr = alloca double, align 8
+ store double %p0, ptr %p0.addr, align 8
+ %0 = load double, ptr %p0.addr, align 8
+ %dx.all = call i1 @llvm.dx.all.f64(double %0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_float
+; CHECK: fcmp une float %{{.*}}, 0.000000e+00
+define noundef i1 @all_float(float noundef %p0) {
+entry:
+ %p0.addr = alloca float, align 4
+ store float %p0, ptr %p0.addr, align 4
+ %0 = load float, ptr %p0.addr, align 4
+ %dx.all = call i1 @llvm.dx.all.f32(float %0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_half
+; CHECK: fcmp une half %{{.*}}, 0xH0000
+define noundef i1 @all_half(half noundef %p0) {
+entry:
+ %p0.addr = alloca half, align 2
+ store half %p0, ptr %p0.addr, align 2
+ %0 = load half, ptr %p0.addr, align 2
+ %dx.all = call i1 @llvm.dx.all.f16(half %0)
+ ret i1 %dx.all
+}
+
+; CHECK-LABEL: all_bool4
+; CHECK: icmp ne <4 x i1> %extractvec, zeroinitializer
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 0
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 1
+; CHECK: and i1 %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 2
+; CHECK: and i1 %{{.*}}, %{{.*}}
+; CHECK: extractelement <4 x i1> %{{.*}}, i64 3
+; CHECK: and i1 %{{.*}}, %{{.*}}
+define noundef i1 @all_bool4(<4 x i1> noundef %p0) {
+entry:
+ %p0.addr = alloca i8, align 1
+ %insertvec = shufflevector <4 x i1> %p0, <4 x i1> poison, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison>
+ %0 = bitcast <8 x i1> %insertvec to i8
+ store i8 %0, ptr %p0.addr, align 1
+ %load_bits = load i8, ptr %p0.addr, align 1
+ %1 = bitcast i8 %load_bits to <8 x i1>
+ %extractvec = shufflevector <8 x i1> %1, <8 x i1> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ %dx.all = call i1 @llvm.dx.all.v4i1(<4 x i1> %extractvec)
+ ret i1 %dx.all
+}
+
+declare i1 @llvm.dx.all.v4i1(<4 x i1>)
+declare i1 @llvm.dx.all.i1(i1)
+declare i1 @llvm.dx.all.i16(i16)
+declare i1 @llvm.dx.all.i32(i32)
+declare i1 @llvm.dx.all.i64(i64)
+declare i1 @llvm.dx.all.f16(half)
+declare i1 @llvm.dx.all.f32(float)
+declare i1 @llvm.dx.all.f64(double)
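If you want to see the lowering locally, the RUN line from the test can be invoked directly against the new file, e.g. (assuming an opt binary from a build with the DirectX target enabled; the build directory path here is a guess):

  build/bin/opt -S -passes=dxil-intrinsic-expansion,dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library llvm/test/CodeGen/DirectX/all.ll

or run the whole test through lit with `build/bin/llvm-lit llvm/test/CodeGen/DirectX/all.ll`.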