[llvm] [DirectX] Add support for vector_reduce_add (PR #117646)
Farzon Lotfi via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 2 12:29:10 PST 2024
https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/117646
From b5bbe3b64e95a263937a6b88ce2c381bebe93a4e Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 25 Nov 2024 18:49:04 -0500
Subject: [PATCH 1/2] [DirectX] Add support for vector_reduce_add
---
.../Target/DirectX/DXILIntrinsicExpansion.cpp | 39 +++
.../test/CodeGen/DirectX/vector_reduce_add.ll | 292 ++++++++++++++++++
2 files changed, 331 insertions(+)
create mode 100644 llvm/test/CodeGen/DirectX/vector_reduce_add.ll
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index d2bfca1fada559..9346700f502701 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -67,11 +67,44 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_sign:
case Intrinsic::dx_step:
case Intrinsic::dx_radians:
+ case Intrinsic::vector_reduce_add:
+ case Intrinsic::vector_reduce_fadd:
return true;
}
return false;
}
+static Value *expandVecReduceFAdd(CallInst *Orig) {
+ // Note: vector_reduce_fadd first argument is a starting value
+ // Our use doesn't need it, so ignoring argument zero.
+ Value *X = Orig->getOperand(1);
+ IRBuilder<> Builder(Orig);
+ Type *Ty = X->getType();
+ auto *XVec = dyn_cast<FixedVectorType>(Ty);
+ unsigned XVecSize = XVec->getNumElements();
+ Value *Sum = Builder.CreateExtractElement(X, static_cast<uint64_t>(0));
+ for (unsigned I = 1; I < XVecSize; I++) {
+ Value *Elt = Builder.CreateExtractElement(X, I);
+ Sum = Builder.CreateFAdd(Sum, Elt);
+ }
+ return Sum;
+}
+
+static Value *expandVecReduceAdd(CallInst *Orig) {
+ Value *X = Orig->getOperand(0);
+ IRBuilder<> Builder(Orig);
+ Type *Ty = X->getType();
+ auto *XVec = dyn_cast<FixedVectorType>(Ty);
+ unsigned XVecSize = XVec->getNumElements();
+
+ Value *Sum = Builder.CreateExtractElement(X, static_cast<uint64_t>(0));
+ for (unsigned I = 1; I < XVecSize; I++) {
+ Value *Elt = Builder.CreateExtractElement(X, I);
+ Sum = Builder.CreateAdd(Sum, Elt);
+ }
+ return Sum;
+}
+
static Value *expandAbs(CallInst *Orig) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig);
@@ -580,6 +613,12 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
case Intrinsic::dx_radians:
Result = expandRadiansIntrinsic(Orig);
break;
+ case Intrinsic::vector_reduce_add:
+ Result = expandVecReduceAdd(Orig);
+ break;
+ case Intrinsic::vector_reduce_fadd:
+ Result = expandVecReduceFAdd(Orig);
+ break;
}
if (Result) {
Orig->replaceAllUsesWith(Result);
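
For readers skimming the patch: a minimal sketch of what this expansion produces for a <4 x i32> input (illustration only, not part of the patch; the function name below is made up). A call like
  %r = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %v)
is rewritten into per-lane extracts accumulated left to right with scalar adds, which is what the autogenerated checks in the test file below verify:

define i32 @reduce_v4i32_expanded(<4 x i32> %v) {
entry:
  ; lane 0 seeds the accumulator
  %e0 = extractelement <4 x i32> %v, i64 0
  %e1 = extractelement <4 x i32> %v, i64 1
  %s1 = add i32 %e0, %e1
  %e2 = extractelement <4 x i32> %v, i64 2
  %s2 = add i32 %s1, %e2
  %e3 = extractelement <4 x i32> %v, i64 3
  %s3 = add i32 %s2, %e3
  ret i32 %s3
}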
diff --git a/llvm/test/CodeGen/DirectX/vector_reduce_add.ll b/llvm/test/CodeGen/DirectX/vector_reduce_add.ll
new file mode 100644
index 00000000000000..96d85534f2cb05
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/vector_reduce_add.ll
@@ -0,0 +1,292 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s
+
+; Make sure the llvm.vector.reduce.fadd and llvm.vector.reduce.add intrinsics are expanded.
+
+define noundef half @test_length_half2(<2 x half> noundef %p0) {
+; CHECK-LABEL: define noundef half @test_length_half2(
+; CHECK-SAME: <2 x half> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x half> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x half> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = fadd half [[TMP0]], [[TMP1]]
+; CHECK-NEXT: ret half [[TMP2]]
+;
+entry:
+ %rdx.fadd = call half @llvm.vector.reduce.fadd.v2f16(half 0xH0000, <2 x half> %p0)
+ ret half %rdx.fadd
+}
+
+define noundef half @test_length_half3(<3 x half> noundef %p0) {
+; CHECK-LABEL: define noundef half @test_length_half3(
+; CHECK-SAME: <3 x half> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <3 x half> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <3 x half> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = fadd half [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = extractelement <3 x half> [[P0]], i64 2
+; CHECK-NEXT: [[TMP4:%.*]] = fadd half [[TMP2]], [[TMP3]]
+; CHECK-NEXT: ret half [[TMP4]]
+;
+entry:
+ %rdx.fadd = call half @llvm.vector.reduce.fadd.v3f16(half 0xH0000, <3 x half> %p0)
+ ret half %rdx.fadd
+}
+
+define noundef half @test_length_half4(<4 x half> noundef %p0) {
+; CHECK-LABEL: define noundef half @test_length_half4(
+; CHECK-SAME: <4 x half> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x half> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x half> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = fadd half [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x half> [[P0]], i64 2
+; CHECK-NEXT: [[TMP4:%.*]] = fadd half [[TMP2]], [[TMP3]]
+; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x half> [[P0]], i64 3
+; CHECK-NEXT: [[TMP6:%.*]] = fadd half [[TMP4]], [[TMP5]]
+; CHECK-NEXT: ret half [[TMP6]]
+;
+entry:
+ %rdx.fadd = call half @llvm.vector.reduce.fadd.v4f16(half 0xH0000, <4 x half> %p0)
+ ret half %rdx.fadd
+}
+
+define noundef float @test_length_float2(<2 x float> noundef %p0) {
+; CHECK-LABEL: define noundef float @test_length_float2(
+; CHECK-SAME: <2 x float> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x float> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x float> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = fadd float [[TMP0]], [[TMP1]]
+; CHECK-NEXT: ret float [[TMP2]]
+;
+entry:
+ %rdx.fadd = call float @llvm.vector.reduce.fadd.v2f32(float 0.000000e+00, <2 x float> %p0)
+ ret float %rdx.fadd
+}
+
+define noundef float @test_length_float3(<3 x float> noundef %p0) {
+; CHECK-LABEL: define noundef float @test_length_float3(
+; CHECK-SAME: <3 x float> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <3 x float> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <3 x float> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = fadd float [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = extractelement <3 x float> [[P0]], i64 2
+; CHECK-NEXT: [[TMP4:%.*]] = fadd float [[TMP2]], [[TMP3]]
+; CHECK-NEXT: ret float [[TMP4]]
+;
+entry:
+ %rdx.fadd = call float @llvm.vector.reduce.fadd.v3f32(float 0.000000e+00, <3 x float> %p0)
+ ret float %rdx.fadd
+}
+
+define noundef float @test_length_float4(<4 x float> noundef %p0) {
+; CHECK-LABEL: define noundef float @test_length_float4(
+; CHECK-SAME: <4 x float> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x float> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x float> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = fadd float [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x float> [[P0]], i64 2
+; CHECK-NEXT: [[TMP4:%.*]] = fadd float [[TMP2]], [[TMP3]]
+; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x float> [[P0]], i64 3
+; CHECK-NEXT: [[TMP6:%.*]] = fadd float [[TMP4]], [[TMP5]]
+; CHECK-NEXT: ret float [[TMP6]]
+;
+entry:
+ %rdx.fadd = call float @llvm.vector.reduce.fadd.v4f32(float 0.000000e+00, <4 x float> %p0)
+ ret float %rdx.fadd
+}
+
+define noundef double @test_length_double2(<2 x double> noundef %p0) {
+; CHECK-LABEL: define noundef double @test_length_double2(
+; CHECK-SAME: <2 x double> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x double> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x double> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = fadd double [[TMP0]], [[TMP1]]
+; CHECK-NEXT: ret double [[TMP2]]
+;
+entry:
+ %rdx.fadd = call double @llvm.vector.reduce.fadd.v2f64(double 0.000000e+00, <2 x double> %p0)
+ ret double %rdx.fadd
+}
+
+define noundef double @test_length_double3(<3 x double> noundef %p0) {
+; CHECK-LABEL: define noundef double @test_length_double3(
+; CHECK-SAME: <3 x double> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <3 x double> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <3 x double> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = fadd double [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = extractelement <3 x double> [[P0]], i64 2
+; CHECK-NEXT: [[TMP4:%.*]] = fadd double [[TMP2]], [[TMP3]]
+; CHECK-NEXT: ret double [[TMP4]]
+;
+entry:
+ %rdx.fadd = call double @llvm.vector.reduce.fadd.v3f64(double 0.000000e+00, <3 x double> %p0)
+ ret double %rdx.fadd
+}
+
+define noundef double @test_length_double4(<4 x double> noundef %p0) {
+; CHECK-LABEL: define noundef double @test_length_double4(
+; CHECK-SAME: <4 x double> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x double> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x double> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = fadd double [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x double> [[P0]], i64 2
+; CHECK-NEXT: [[TMP4:%.*]] = fadd double [[TMP2]], [[TMP3]]
+; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x double> [[P0]], i64 3
+; CHECK-NEXT: [[TMP6:%.*]] = fadd double [[TMP4]], [[TMP5]]
+; CHECK-NEXT: ret double [[TMP6]]
+;
+entry:
+ %rdx.fadd = call double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> %p0)
+ ret double %rdx.fadd
+}
+
+define noundef i16 @test_length_short2(<2 x i16> noundef %p0) {
+; CHECK-LABEL: define noundef i16 @test_length_short2(
+; CHECK-SAME: <2 x i16> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x i16> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x i16> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = add i16 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: ret i16 [[TMP2]]
+;
+entry:
+ %rdx.add = call i16 @llvm.vector.reduce.add.v2i16(<2 x i16> %p0)
+ ret i16 %rdx.add
+}
+
+define noundef i16 @test_length_short3(<3 x i16> noundef %p0) {
+; CHECK-LABEL: define noundef i16 @test_length_short3(
+; CHECK-SAME: <3 x i16> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <3 x i16> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <3 x i16> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = add i16 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = extractelement <3 x i16> [[P0]], i64 2
+; CHECK-NEXT: [[TMP4:%.*]] = add i16 [[TMP2]], [[TMP3]]
+; CHECK-NEXT: ret i16 [[TMP4]]
+;
+entry:
+ %rdx.fadd = call i16 @llvm.vector.reduce.add.v3i16(<3 x i16> %p0)
+ ret i16 %rdx.fadd
+}
+
+define noundef i16 @test_length_short4(<4 x i16> noundef %p0) {
+; CHECK-LABEL: define noundef i16 @test_length_short4(
+; CHECK-SAME: <4 x i16> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x i16> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x i16> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = add i16 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x i16> [[P0]], i64 2
+; CHECK-NEXT: [[TMP4:%.*]] = add i16 [[TMP2]], [[TMP3]]
+; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x i16> [[P0]], i64 3
+; CHECK-NEXT: [[TMP6:%.*]] = add i16 [[TMP4]], [[TMP5]]
+; CHECK-NEXT: ret i16 [[TMP6]]
+;
+entry:
+ %rdx.fadd = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> %p0)
+ ret i16 %rdx.fadd
+}
+
+define noundef i32 @test_length_int2(<2 x i32> noundef %p0) {
+; CHECK-LABEL: define noundef i32 @test_length_int2(
+; CHECK-SAME: <2 x i32> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x i32> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x i32> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: ret i32 [[TMP2]]
+;
+entry:
+ %rdx.add = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> %p0)
+ ret i32 %rdx.add
+}
+
+define noundef i32 @test_length_int3(<3 x i32> noundef %p0) {
+; CHECK-LABEL: define noundef i32 @test_length_int3(
+; CHECK-SAME: <3 x i32> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <3 x i32> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <3 x i32> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = extractelement <3 x i32> [[P0]], i64 2
+; CHECK-NEXT: [[TMP4:%.*]] = add i32 [[TMP2]], [[TMP3]]
+; CHECK-NEXT: ret i32 [[TMP4]]
+;
+entry:
+ %rdx.fadd = call i32 @llvm.vector.reduce.add.v3i32(<3 x i32> %p0)
+ ret i32 %rdx.fadd
+}
+
+define noundef i32 @test_length_int4(<4 x i32> noundef %p0) {
+; CHECK-LABEL: define noundef i32 @test_length_int4(
+; CHECK-SAME: <4 x i32> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x i32> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x i32> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x i32> [[P0]], i64 2
+; CHECK-NEXT: [[TMP4:%.*]] = add i32 [[TMP2]], [[TMP3]]
+; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x i32> [[P0]], i64 3
+; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP4]], [[TMP5]]
+; CHECK-NEXT: ret i32 [[TMP6]]
+;
+entry:
+ %rdx.fadd = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %p0)
+ ret i32 %rdx.fadd
+}
+
+define noundef i64 @test_length_int64_2(<2 x i64> noundef %p0) {
+; CHECK-LABEL: define noundef i64 @test_length_int64_2(
+; CHECK-SAME: <2 x i64> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x i64> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <2 x i64> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = add i64 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: ret i64 [[TMP2]]
+;
+entry:
+ %rdx.add = call i64 @llvm.vector.reduce.add.v2i64(<2 x i64> %p0)
+ ret i64 %rdx.add
+}
+
+define noundef i64 @test_length_int64_3(<3 x i64> noundef %p0) {
+; CHECK-LABEL: define noundef i64 @test_length_int64_3(
+; CHECK-SAME: <3 x i64> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <3 x i64> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <3 x i64> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = add i64 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = extractelement <3 x i64> [[P0]], i64 2
+; CHECK-NEXT: [[TMP4:%.*]] = add i64 [[TMP2]], [[TMP3]]
+; CHECK-NEXT: ret i64 [[TMP4]]
+;
+entry:
+ %rdx.fadd = call i64 @llvm.vector.reduce.add.v3i64(<3 x i64> %p0)
+ ret i64 %rdx.fadd
+}
+
+define noundef i64 @test_length_int64_4(<4 x i64> noundef %p0) {
+; CHECK-LABEL: define noundef i64 @test_length_int64_4(
+; CHECK-SAME: <4 x i64> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x i64> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = extractelement <4 x i64> [[P0]], i64 1
+; CHECK-NEXT: [[TMP2:%.*]] = add i64 [[TMP0]], [[TMP1]]
+; CHECK-NEXT: [[TMP3:%.*]] = extractelement <4 x i64> [[P0]], i64 2
+; CHECK-NEXT: [[TMP4:%.*]] = add i64 [[TMP2]], [[TMP3]]
+; CHECK-NEXT: [[TMP5:%.*]] = extractelement <4 x i64> [[P0]], i64 3
+; CHECK-NEXT: [[TMP6:%.*]] = add i64 [[TMP4]], [[TMP5]]
+; CHECK-NEXT: ret i64 [[TMP6]]
+;
+entry:
+ %rdx.fadd = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> %p0)
+ ret i64 %rdx.fadd
+}
From af564009ae0ad79d6294f6fb95dfd40eafc2ed20 Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Mon, 2 Dec 2024 15:21:00 -0500
Subject: [PATCH 2/2] merge add and fadd
---
.../Target/DirectX/DXILIntrinsicExpansion.cpp | 43 ++++++++--------
.../test/CodeGen/DirectX/vector_reduce_add.ll | 51 +++++++++++++++++++
2 files changed, 73 insertions(+), 21 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 9346700f502701..1242ee2a3abccb 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -73,35 +73,38 @@ static bool isIntrinsicExpansion(Function &F) {
}
return false;
}
+static Value *expandVecReduceAdd(CallInst *Orig, Intrinsic::ID IntrinsicId) {
+ assert(IntrinsicId == Intrinsic::vector_reduce_add ||
+ IntrinsicId == Intrinsic::vector_reduce_fadd);
-static Value *expandVecReduceFAdd(CallInst *Orig) {
- // Note: vector_reduce_fadd first argument is a starting value
- // Our use doesn't need it, so ignoring argument zero.
- Value *X = Orig->getOperand(1);
IRBuilder<> Builder(Orig);
+ bool IsFAdd = (IntrinsicId == Intrinsic::vector_reduce_fadd);
+
+ // Define the addition operation based on the intrinsic ID.
+ auto AddOp = [&Builder, IsFAdd](Value *Sum, Value *Elt) {
+ return IsFAdd ? Builder.CreateFAdd(Sum, Elt) : Builder.CreateAdd(Sum, Elt);
+ };
+
+ Value *X = Orig->getOperand(IsFAdd ? 1 : 0);
Type *Ty = X->getType();
auto *XVec = dyn_cast<FixedVectorType>(Ty);
unsigned XVecSize = XVec->getNumElements();
Value *Sum = Builder.CreateExtractElement(X, static_cast<uint64_t>(0));
- for (unsigned I = 1; I < XVecSize; I++) {
- Value *Elt = Builder.CreateExtractElement(X, I);
- Sum = Builder.CreateFAdd(Sum, Elt);
- }
- return Sum;
-}
-static Value *expandVecReduceAdd(CallInst *Orig) {
- Value *X = Orig->getOperand(0);
- IRBuilder<> Builder(Orig);
- Type *Ty = X->getType();
- auto *XVec = dyn_cast<FixedVectorType>(Ty);
- unsigned XVecSize = XVec->getNumElements();
+ // Handle the initial start value for floating-point addition.
+ if (IsFAdd) {
+ llvm::Constant *StartValue =
+ llvm::dyn_cast<llvm::Constant>(Orig->getOperand(0));
+ if (StartValue && !StartValue->isZeroValue())
+ Sum = Builder.CreateFAdd(Sum, StartValue);
+ }
- Value *Sum = Builder.CreateExtractElement(X, static_cast<uint64_t>(0));
+ // Accumulate the remaining vector elements.
for (unsigned I = 1; I < XVecSize; I++) {
Value *Elt = Builder.CreateExtractElement(X, I);
- Sum = Builder.CreateAdd(Sum, Elt);
+ Sum = AddOp(Sum, Elt);
}
+
return Sum;
}
@@ -614,10 +617,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
Result = expandRadiansIntrinsic(Orig);
break;
case Intrinsic::vector_reduce_add:
- Result = expandVecReduceAdd(Orig);
- break;
case Intrinsic::vector_reduce_fadd:
- Result = expandVecReduceFAdd(Orig);
+ Result = expandVecReduceAdd(Orig, IntrinsicId);
break;
}
if (Result) {
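
A minimal sketch of the new start-value handling (illustration only, not part of the patch; the function name is made up): when the scalar start operand of llvm.vector.reduce.fadd is a nonzero constant, it is folded into the first lane's sum before the remaining lanes are accumulated, which the new *_start1 tests below check. For
  %r = call float @llvm.vector.reduce.fadd.v2f32(float 1.000000e+00, <2 x float> %v)
the expansion looks like:

define float @reduce_fadd_v2f32_start1(<2 x float> %v) {
entry:
  %e0 = extractelement <2 x float> %v, i64 0
  ; the nonzero start value is added to lane 0 first
  %s0 = fadd float %e0, 1.000000e+00
  %e1 = extractelement <2 x float> %v, i64 1
  %s1 = fadd float %s0, %e1
  ret float %s1
}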
diff --git a/llvm/test/CodeGen/DirectX/vector_reduce_add.ll b/llvm/test/CodeGen/DirectX/vector_reduce_add.ll
index 96d85534f2cb05..d4ee16a24cb45f 100644
--- a/llvm/test/CodeGen/DirectX/vector_reduce_add.ll
+++ b/llvm/test/CodeGen/DirectX/vector_reduce_add.ll
@@ -17,6 +17,21 @@ entry:
ret half %rdx.fadd
}
+define noundef half @test_length_half2_start1(<2 x half> noundef %p0) {
+; CHECK-LABEL: define noundef half @test_length_half2_start1(
+; CHECK-SAME: <2 x half> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x half> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = fadd half [[TMP0]], 0xH0001
+; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x half> [[P0]], i64 1
+; CHECK-NEXT: [[TMP3:%.*]] = fadd half [[TMP1]], [[TMP2]]
+; CHECK-NEXT: ret half [[TMP3]]
+;
+entry:
+ %rdx.fadd = call half @llvm.vector.reduce.fadd.v2f16(half 0xH0001, <2 x half> %p0)
+ ret half %rdx.fadd
+}
+
define noundef half @test_length_half3(<3 x half> noundef %p0) {
; CHECK-LABEL: define noundef half @test_length_half3(
; CHECK-SAME: <3 x half> noundef [[P0:%.*]]) {
@@ -81,6 +96,23 @@ entry:
ret float %rdx.fadd
}
+define noundef float @test_length_float3_start1(<3 x float> noundef %p0) {
+; CHECK-LABEL: define noundef float @test_length_float3_start1(
+; CHECK-SAME: <3 x float> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <3 x float> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = fadd float [[TMP0]], 1.000000e+00
+; CHECK-NEXT: [[TMP2:%.*]] = extractelement <3 x float> [[P0]], i64 1
+; CHECK-NEXT: [[TMP3:%.*]] = fadd float [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[TMP4:%.*]] = extractelement <3 x float> [[P0]], i64 2
+; CHECK-NEXT: [[TMP5:%.*]] = fadd float [[TMP3]], [[TMP4]]
+; CHECK-NEXT: ret float [[TMP5]]
+;
+entry:
+ %rdx.fadd = call float @llvm.vector.reduce.fadd.v3f32(float 1.000000e+00, <3 x float> %p0)
+ ret float %rdx.fadd
+}
+
define noundef float @test_length_float4(<4 x float> noundef %p0) {
; CHECK-LABEL: define noundef float @test_length_float4(
; CHECK-SAME: <4 x float> noundef [[P0:%.*]]) {
@@ -147,6 +179,25 @@ entry:
ret double %rdx.fadd
}
+define noundef double @test_length_double4_start1(<4 x double> noundef %p0) {
+; CHECK-LABEL: define noundef double @test_length_double4_start1(
+; CHECK-SAME: <4 x double> noundef [[P0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x double> [[P0]], i64 0
+; CHECK-NEXT: [[TMP1:%.*]] = fadd double [[TMP0]], 1.000000e+00
+; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x double> [[P0]], i64 1
+; CHECK-NEXT: [[TMP3:%.*]] = fadd double [[TMP1]], [[TMP2]]
+; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x double> [[P0]], i64 2
+; CHECK-NEXT: [[TMP5:%.*]] = fadd double [[TMP3]], [[TMP4]]
+; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x double> [[P0]], i64 3
+; CHECK-NEXT: [[TMP7:%.*]] = fadd double [[TMP5]], [[TMP6]]
+; CHECK-NEXT: ret double [[TMP7]]
+;
+entry:
+ %rdx.fadd = call double @llvm.vector.reduce.fadd.v4f64(double 1.000000e+00, <4 x double> %p0)
+ ret double %rdx.fadd
+}
+
define noundef i16 @test_length_short2(<2 x i16> noundef %p0) {
; CHECK-LABEL: define noundef i16 @test_length_short2(
; CHECK-SAME: <2 x i16> noundef [[P0:%.*]]) {