[clang] [llvm] [HLSL] Re-implement countbits with the correct return type (PR #113189)
Sarah Spall via cfe-commits
cfe-commits at lists.llvm.org
Fri Oct 25 14:52:48 PDT 2024
https://github.com/spall updated https://github.com/llvm/llvm-project/pull/113189
>From 12cac48dcc10ef9c5fccba2c22911f420298b98b Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Wed, 16 Oct 2024 19:00:08 +0000
Subject: [PATCH 1/4] implement countbits correctly
---
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 126 +++++++++++-------
.../test/CodeGenHLSL/builtins/countbits.hlsl | 42 +++---
.../SemaHLSL/BuiltIns/countbits-errors.hlsl | 14 +-
llvm/lib/Target/DirectX/DXIL.td | 5 +-
llvm/lib/Target/DirectX/DXILOpLowering.cpp | 64 +++++++++
llvm/test/CodeGen/DirectX/countbits.ll | 39 ++++--
6 files changed, 202 insertions(+), 88 deletions(-)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 30dce60b3ff702..2a612c3746076c 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -705,66 +705,90 @@ float4 cosh(float4);
#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t countbits(int16_t);
+constexpr uint countbits(int16_t x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t2 countbits(int16_t2);
+constexpr uint2 countbits(int16_t2 x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t3 countbits(int16_t3);
+constexpr uint3 countbits(int16_t3 x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t4 countbits(int16_t4);
+constexpr uint4 countbits(int16_t4 x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t countbits(uint16_t);
+constexpr uint countbits(uint16_t x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t2 countbits(uint16_t2);
+constexpr uint2 countbits(uint16_t2 x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t3 countbits(uint16_t3);
+constexpr uint3 countbits(uint16_t3 x) {
+ return __builtin_elementwise_popcount(x);
+}
_HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t4 countbits(uint16_t4);
+constexpr uint4 countbits(uint16_t4 x) {
+ return __builtin_elementwise_popcount(x);
+}
#endif
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int countbits(int);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int2 countbits(int2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int3 countbits(int3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int4 countbits(int4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint countbits(uint);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint2 countbits(uint2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint3 countbits(uint3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint4 countbits(uint4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t countbits(int64_t);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t2 countbits(int64_t2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t3 countbits(int64_t3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t4 countbits(int64_t4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t countbits(uint64_t);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t2 countbits(uint64_t2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t3 countbits(uint64_t3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t4 countbits(uint64_t4);
+constexpr uint countbits(int x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(int2 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(int3 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(int4 x) {
+ return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(uint x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(uint2 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(uint3 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(uint4 x) {
+ return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(int64_t x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(int64_t2 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(int64_t3 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(int64_t4 x) {
+ return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(uint64_t x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(uint64_t2 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(uint64_t3 x) {
+ return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(uint64_t4 x) {
+ return __builtin_elementwise_popcount(x);
+}
//===----------------------------------------------------------------------===//
// degrees builtins
diff --git a/clang/test/CodeGenHLSL/builtins/countbits.hlsl b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
index 8dfe977bfae626..aa9ef40d7a0dc8 100644
--- a/clang/test/CodeGenHLSL/builtins/countbits.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
@@ -4,26 +4,30 @@
#ifdef __HLSL_ENABLE_16_BIT
// CHECK-LABEL: test_countbits_ushort
-// CHECK: call i16 @llvm.ctpop.i16
-uint16_t test_countbits_ushort(uint16_t p0)
+// CHECK: [[A:%.*]] = call i16 @llvm.ctpop.i16
+// CHECK-NEXT: zext i16 [[A]] to i32
+uint test_countbits_ushort(uint16_t p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_ushort2
-// CHECK: call <2 x i16> @llvm.ctpop.v2i16
-uint16_t2 test_countbits_ushort2(uint16_t2 p0)
+// CHECK: [[A:%.*]] = call <2 x i16> @llvm.ctpop.v2i16
+// CHECK-NEXT: zext <2 x i16> [[A]] to <2 x i32>
+uint2 test_countbits_ushort2(uint16_t2 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_ushort3
-// CHECK: call <3 x i16> @llvm.ctpop.v3i16
-uint16_t3 test_countbits_ushort3(uint16_t3 p0)
+// CHECK: [[A:%.*]] = call <3 x i16> @llvm.ctpop.v3i16
+// CHECK-NEXT: zext <3 x i16> [[A]] to <3 x i32>
+uint3 test_countbits_ushort3(uint16_t3 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_ushort4
-// CHECK: call <4 x i16> @llvm.ctpop.v4i16
-uint16_t4 test_countbits_ushort4(uint16_t4 p0)
+// CHECK: [[A:%.*]] = call <4 x i16> @llvm.ctpop.v4i16
+// CHECK-NEXT: zext <4 x i16> [[A]] to <4 x i32>
+uint4 test_countbits_ushort4(uint16_t4 p0)
{
return countbits(p0);
}
@@ -31,7 +35,7 @@ uint16_t4 test_countbits_ushort4(uint16_t4 p0)
// CHECK-LABEL: test_countbits_uint
// CHECK: call i32 @llvm.ctpop.i32
-int test_countbits_uint(uint p0)
+uint test_countbits_uint(uint p0)
{
return countbits(p0);
}
@@ -55,26 +59,30 @@ uint4 test_countbits_uint4(uint4 p0)
}
// CHECK-LABEL: test_countbits_long
-// CHECK: call i64 @llvm.ctpop.i64
-uint64_t test_countbits_long(uint64_t p0)
+// CHECK: [[A:%.*]] = call i64 @llvm.ctpop.i64
+// CHECK-NEXT: trunc i64 [[A]] to i32
+uint test_countbits_long(uint64_t p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_long2
-// CHECK: call <2 x i64> @llvm.ctpop.v2i64
-uint64_t2 test_countbits_long2(uint64_t2 p0)
+// CHECK: [[A:%.*]] = call <2 x i64> @llvm.ctpop.v2i64
+// CHECK-NEXT: trunc <2 x i64> [[A]] to <2 x i32>
+uint2 test_countbits_long2(uint64_t2 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_long3
-// CHECK: call <3 x i64> @llvm.ctpop.v3i64
-uint64_t3 test_countbits_long3(uint64_t3 p0)
+// CHECK: [[A:%.*]] = call <3 x i64> @llvm.ctpop.v3i64
+// CHECK-NEXT: trunc <3 x i64> [[A]] to <3 x i32>
+uint3 test_countbits_long3(uint64_t3 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_long4
-// CHECK: call <4 x i64> @llvm.ctpop.v4i64
-uint64_t4 test_countbits_long4(uint64_t4 p0)
+// CHECK: [[A:%.*]] = call <4 x i64> @llvm.ctpop.v4i64
+// CHECK-NEXT: trunc <4 x i64> [[A]] to <4 x i32>
+uint4 test_countbits_long4(uint64_t4 p0)
{
return countbits(p0);
}
diff --git a/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
index 8d5f0abb2860f8..5704165e1a4505 100644
--- a/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
@@ -1,6 +1,4 @@
-// RUN: %clang_cc1 -finclude-default-header
-// -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only
-// -disable-llvm-passes -verify -verify-ignore-unexpected
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
double test_int_builtin(double p0) {
@@ -9,13 +7,11 @@ double test_int_builtin(double p0) {
}
double2 test_int_builtin_2(double2 p0) {
- return __builtin_elementwise_popcount(p0);
- // expected-error@-1 {{1st argument must be a vector of integers
- // (was 'double2' (aka 'vector<double, 2>'))}}
+ return countbits(p0);
+ // expected-error@-1 {{call to 'countbits' is ambiguous}}
}
double test_int_builtin_3(float p0) {
- return __builtin_elementwise_popcount(p0);
- // expected-error@-1 {{1st argument must be a vector of integers
- // (was 'float')}}
+ return countbits(p0);
+ // expected-error@-1 {{call to 'countbits' is ambiguous}}
}
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..73636739de0659 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -553,11 +553,10 @@ def Rbits : DXILOp<30, unary> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
-def CBits : DXILOp<31, unary> {
+def CBits : DXILOp<31, unaryBits> {
let Doc = "Returns the number of 1 bits in the specified value.";
- let LLVMIntrinsic = int_ctpop;
let arguments = [OverloadTy];
- let result = OverloadTy;
+ let result = Int32Ty;
let overloads =
[Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index c62ba8c21d6791..2eb02b32ed2443 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -460,6 +460,67 @@ class OpLowerer {
});
}
+ [[nodiscard]] bool lowerCtpopToCBits(Function &F) {
+ IRBuilder<> &IRB = OpBuilder.getIRB();
+ Type *Int32Ty = IRB.getInt32Ty();
+
+ return replaceFunction(F, [&](CallInst *CI) -> Error {
+ IRB.SetInsertPoint(CI);
+ SmallVector<Value *> Args;
+ Args.append(CI->arg_begin(), CI->arg_end());
+
+ Type *RetTy = Int32Ty;
+ Type *FRT = F.getReturnType();
+ if (FRT->isVectorTy()) {
+ VectorType *VT = cast<VectorType>(FRT);
+ RetTy = VectorType::get(RetTy, VT);
+ }
+
+ Expected<CallInst *> OpCall =
+ OpBuilder.tryCreateOp(dxil::OpCode::CBits, Args, CI->getName(), RetTy);
+ if (Error E = OpCall.takeError())
+ return E;
+
+ // If the result type is 32 bits we can do a direct replacement.
+ if (FRT->isIntOrIntVectorTy(32)) {
+ CI->replaceAllUsesWith(*OpCall);
+ CI->eraseFromParent();
+ return Error::success();
+ }
+
+ unsigned CastOp;
+ if (FRT->isIntOrIntVectorTy(16))
+ CastOp = Instruction::ZExt;
+ else // must be 64 bits
+ CastOp = Instruction::Trunc;
+
+ // It is correct to replace the ctpop with the dxil op and
+ // remove an existing cast iff the cast is the only usage of
+ // the ctpop
+ // can use hasOneUse instead of hasOneUser, because the user
+ // we care about should have one operand
+ if (CI->hasOneUse()) {
+ User *U = CI->user_back();
+ Instruction *I;
+ if (isa<Instruction>(U) && (I = cast<Instruction>(U)) &&
+ I->getOpcode() == CastOp && I->getType() == RetTy) {
+ I->replaceAllUsesWith(*OpCall);
+ I->eraseFromParent();
+ CI->eraseFromParent();
+ return Error::success();
+ }
+ }
+
+ // It is always correct to replace a ctpop with the dxil op and
+ // a cast
+ Value *Cast = IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(),
+ "ctpop.cast");
+ CI->replaceAllUsesWith(Cast);
+ CI->eraseFromParent();
+ return Error::success();
+ });
+ }
+
bool lowerIntrinsics() {
bool Updated = false;
bool HasErrors = false;
@@ -488,6 +549,9 @@ class OpLowerer {
case Intrinsic::dx_typedBufferStore:
HasErrors |= lowerTypedBufferStore(F);
break;
+ case Intrinsic::ctpop:
+ HasErrors |= lowerCtpopToCBits(F);
+ break;
}
Updated = true;
}
diff --git a/llvm/test/CodeGen/DirectX/countbits.ll b/llvm/test/CodeGen/DirectX/countbits.ll
index c6bc2b6790948e..91f6f560903f01 100644
--- a/llvm/test/CodeGen/DirectX/countbits.ll
+++ b/llvm/test/CodeGen/DirectX/countbits.ll
@@ -4,35 +4,58 @@
define noundef i16 @test_countbits_short(i16 noundef %a) {
entry:
-; CHECK: call i16 @dx.op.unary.i16(i32 31, i16 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i16(i32 31, i16 %{{.*}})
+; CHECK-NEXT: [[B:%.*]] = trunc i32 [[A]] to i16
+; CHECK-NEXT: ret i16 [[B]]
%elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a)
ret i16 %elt.ctpop
}
+define noundef i32 @test_countbits_short2(i16 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i16(i32 31, i16 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
+ %elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a)
+ %elt.zext = zext i16 %elt.ctpop to i32
+ ret i32 %elt.zext
+}
+
define noundef i32 @test_countbits_int(i32 noundef %a) {
entry:
-; CHECK: call i32 @dx.op.unary.i32(i32 31, i32 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
%elt.ctpop = call i32 @llvm.ctpop.i32(i32 %a)
ret i32 %elt.ctpop
}
define noundef i64 @test_countbits_long(i64 noundef %a) {
entry:
-; CHECK: call i64 @dx.op.unary.i64(i32 31, i64 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i64(i32 31, i64 %{{.*}})
+; CHECK-NEXT: [[B:%.*]] = zext i32 [[A]] to i64
+; CHECK-NEXT: ret i64 [[B]]
%elt.ctpop = call i64 @llvm.ctpop.i64(i64 %a)
ret i64 %elt.ctpop
}
+define noundef i32 @test_countbits_long2(i64 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i64(i32 31, i64 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
+ %elt.ctpop = call i64 @llvm.ctpop.i64(i64 %a)
+ %elt.trunc = trunc i64 %elt.ctpop to i32
+ ret i32 %elt.trunc
+}
+
define noundef <4 x i32> @countbits_vec4_i32(<4 x i32> noundef %a) {
entry:
; CHECK: [[ee0:%.*]] = extractelement <4 x i32> %a, i64 0
- ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee0]])
+ ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee0]])
; CHECK: [[ee1:%.*]] = extractelement <4 x i32> %a, i64 1
- ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee1]])
+ ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee1]])
; CHECK: [[ee2:%.*]] = extractelement <4 x i32> %a, i64 2
- ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee2]])
+ ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee2]])
; CHECK: [[ee3:%.*]] = extractelement <4 x i32> %a, i64 3
- ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee3]])
+ ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee3]])
; CHECK: insertelement <4 x i32> poison, i32 [[ie0]], i64 0
; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie1]], i64 1
; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie2]], i64 2
@@ -44,4 +67,4 @@ entry:
declare i16 @llvm.ctpop.i16(i16)
declare i32 @llvm.ctpop.i32(i32)
declare i64 @llvm.ctpop.i64(i64)
-declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)
+declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)
\ No newline at end of file
>From 7b368e394aa7251271d3a548f6fb2b5158f1ad16 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Mon, 21 Oct 2024 16:31:07 +0000
Subject: [PATCH 2/4] make clang format happy
---
clang/lib/Headers/hlsl/hlsl_intrinsics.h | 34 +++++------------
llvm/lib/Target/DirectX/DXILOpLowering.cpp | 44 +++++++++++-----------
2 files changed, 31 insertions(+), 47 deletions(-)
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 2a612c3746076c..dbb55b5ecda432 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -738,31 +738,15 @@ constexpr uint4 countbits(uint16_t4 x) {
}
#endif
-constexpr uint countbits(int x) {
- return __builtin_elementwise_popcount(x);
-}
-constexpr uint2 countbits(int2 x) {
- return __builtin_elementwise_popcount(x);
-}
-constexpr uint3 countbits(int3 x) {
- return __builtin_elementwise_popcount(x);
-}
-constexpr uint4 countbits(int4 x) {
- return __builtin_elementwise_popcount(x);
-}
-
-constexpr uint countbits(uint x) {
- return __builtin_elementwise_popcount(x);
-}
-constexpr uint2 countbits(uint2 x) {
- return __builtin_elementwise_popcount(x);
-}
-constexpr uint3 countbits(uint3 x) {
- return __builtin_elementwise_popcount(x);
-}
-constexpr uint4 countbits(uint4 x) {
- return __builtin_elementwise_popcount(x);
-}
+constexpr uint countbits(int x) { return __builtin_elementwise_popcount(x); }
+constexpr uint2 countbits(int2 x) { return __builtin_elementwise_popcount(x); }
+constexpr uint3 countbits(int3 x) { return __builtin_elementwise_popcount(x); }
+constexpr uint4 countbits(int4 x) { return __builtin_elementwise_popcount(x); }
+
+constexpr uint countbits(uint x) { return __builtin_elementwise_popcount(x); }
+constexpr uint2 countbits(uint2 x) { return __builtin_elementwise_popcount(x); }
+constexpr uint3 countbits(uint3 x) { return __builtin_elementwise_popcount(x); }
+constexpr uint4 countbits(uint4 x) { return __builtin_elementwise_popcount(x); }
constexpr uint countbits(int64_t x) {
return __builtin_elementwise_popcount(x);
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 2eb02b32ed2443..a3533b8ebcc5f0 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -463,7 +463,7 @@ class OpLowerer {
[[nodiscard]] bool lowerCtpopToCBits(Function &F) {
IRBuilder<> &IRB = OpBuilder.getIRB();
Type *Int32Ty = IRB.getInt32Ty();
-
+
return replaceFunction(F, [&](CallInst *CI) -> Error {
IRB.SetInsertPoint(CI);
SmallVector<Value *> Args;
@@ -473,26 +473,26 @@ class OpLowerer {
Type *FRT = F.getReturnType();
if (FRT->isVectorTy()) {
VectorType *VT = cast<VectorType>(FRT);
- RetTy = VectorType::get(RetTy, VT);
+ RetTy = VectorType::get(RetTy, VT);
}
-
- Expected<CallInst *> OpCall =
- OpBuilder.tryCreateOp(dxil::OpCode::CBits, Args, CI->getName(), RetTy);
+
+ Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
+ dxil::OpCode::CBits, Args, CI->getName(), RetTy);
if (Error E = OpCall.takeError())
- return E;
+ return E;
// If the result type is 32 bits we can do a direct replacement.
if (FRT->isIntOrIntVectorTy(32)) {
CI->replaceAllUsesWith(*OpCall);
- CI->eraseFromParent();
- return Error::success();
+ CI->eraseFromParent();
+ return Error::success();
}
unsigned CastOp;
if (FRT->isIntOrIntVectorTy(16))
- CastOp = Instruction::ZExt;
+ CastOp = Instruction::ZExt;
else // must be 64 bits
- CastOp = Instruction::Trunc;
+ CastOp = Instruction::Trunc;
// It is correct to replace the ctpop with the dxil op and
// remove an existing cast iff the cast is the only usage of
@@ -500,21 +500,21 @@ class OpLowerer {
// can use hasOneUse instead of hasOneUser, because the user
// we care about should have one operand
if (CI->hasOneUse()) {
- User *U = CI->user_back();
- Instruction *I;
- if (isa<Instruction>(U) && (I = cast<Instruction>(U)) &&
- I->getOpcode() == CastOp && I->getType() == RetTy) {
+ User *U = CI->user_back();
+ Instruction *I;
+ if (isa<Instruction>(U) && (I = cast<Instruction>(U)) &&
+ I->getOpcode() == CastOp && I->getType() == RetTy) {
I->replaceAllUsesWith(*OpCall);
- I->eraseFromParent();
- CI->eraseFromParent();
- return Error::success();
- }
+ I->eraseFromParent();
+ CI->eraseFromParent();
+ return Error::success();
+ }
}
// It is always correct to replace a ctpop with the dxil op and
// a cast
- Value *Cast = IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(),
- "ctpop.cast");
+ Value *Cast =
+ IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast");
CI->replaceAllUsesWith(Cast);
CI->eraseFromParent();
return Error::success();
@@ -550,8 +550,8 @@ class OpLowerer {
HasErrors |= lowerTypedBufferStore(F);
break;
case Intrinsic::ctpop:
- HasErrors |= lowerCtpopToCBits(F);
- break;
+ HasErrors |= lowerCtpopToCBits(F);
+ break;
}
Updated = true;
}
>From c9b4453239d95e8a232eea4597ba5257584fa280 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Wed, 23 Oct 2024 16:31:04 +0000
Subject: [PATCH 3/4] address PR comments
---
llvm/lib/Target/DirectX/DXIL.td | 2 +-
llvm/lib/Target/DirectX/DXILOpLowering.cpp | 47 ++++++++++++----------
llvm/test/CodeGen/DirectX/countbits.ll | 2 +-
3 files changed, 27 insertions(+), 24 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 73636739de0659..c0094ec9b05553 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -553,7 +553,7 @@ def Rbits : DXILOp<30, unary> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
-def CBits : DXILOp<31, unaryBits> {
+def CountBits : DXILOp<31, unaryBits> {
let Doc = "Returns the number of 1 bits in the specified value.";
let arguments = [OverloadTy];
let result = Int32Ty;
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index a3533b8ebcc5f0..629ecd35b81c21 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -460,7 +460,7 @@ class OpLowerer {
});
}
- [[nodiscard]] bool lowerCtpopToCBits(Function &F) {
+ [[nodiscard]] bool lowerCtpopToCountBits(Function &F) {
IRBuilder<> &IRB = OpBuilder.getIRB();
Type *Int32Ty = IRB.getInt32Ty();
@@ -471,13 +471,11 @@ class OpLowerer {
Type *RetTy = Int32Ty;
Type *FRT = F.getReturnType();
- if (FRT->isVectorTy()) {
- VectorType *VT = cast<VectorType>(FRT);
+ if (const auto *VT = dyn_cast<VectorType>(FRT))
RetTy = VectorType::get(RetTy, VT);
- }
Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
- dxil::OpCode::CBits, Args, CI->getName(), RetTy);
+ dxil::OpCode::CountBits, Args, CI->getName(), RetTy);
if (Error E = OpCall.takeError())
return E;
@@ -491,31 +489,36 @@ class OpLowerer {
unsigned CastOp;
if (FRT->isIntOrIntVectorTy(16))
CastOp = Instruction::ZExt;
- else // must be 64 bits
+ else { // must be 64 bits
+ assert(FRT->isIntOrIntVectorTy(64) &&
+ "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
+ is supported.");
CastOp = Instruction::Trunc;
+ }
// It is correct to replace the ctpop with the dxil op and
- // remove an existing cast iff the cast is the only usage of
- // the ctpop
- // can use hasOneUse instead of hasOneUser, because the user
- // we care about should have one operand
- if (CI->hasOneUse()) {
- User *U = CI->user_back();
+ // remove all casts to i32
+ bool nonCastInstr = false;
+ for (User *User : make_early_inc_range(CI->users())) {
Instruction *I;
- if (isa<Instruction>(U) && (I = cast<Instruction>(U)) &&
+ if ((I = dyn_cast<Instruction>(User)) != NULL &&
I->getOpcode() == CastOp && I->getType() == RetTy) {
I->replaceAllUsesWith(*OpCall);
I->eraseFromParent();
- CI->eraseFromParent();
- return Error::success();
- }
+ } else
+ nonCastInstr = true;
+ }
+
+ // It is correct to replace a ctpop with the dxil op and
+ // a cast from i32 to the return type of the ctpop
+ // the cast is emitted here if there is a non-cast to i32
+ // instr which uses the ctpop
+ if (nonCastInstr) {
+ Value *Cast =
+ IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast");
+ CI->replaceAllUsesWith(Cast);
}
- // It is always correct to replace a ctpop with the dxil op and
- // a cast
- Value *Cast =
- IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast");
- CI->replaceAllUsesWith(Cast);
CI->eraseFromParent();
return Error::success();
});
@@ -550,7 +553,7 @@ class OpLowerer {
HasErrors |= lowerTypedBufferStore(F);
break;
case Intrinsic::ctpop:
- HasErrors |= lowerCtpopToCBits(F);
+ HasErrors |= lowerCtpopToCountBits(F);
break;
}
Updated = true;
diff --git a/llvm/test/CodeGen/DirectX/countbits.ll b/llvm/test/CodeGen/DirectX/countbits.ll
index 91f6f560903f01..11675d67ed0cba 100644
--- a/llvm/test/CodeGen/DirectX/countbits.ll
+++ b/llvm/test/CodeGen/DirectX/countbits.ll
@@ -67,4 +67,4 @@ entry:
declare i16 @llvm.ctpop.i16(i16)
declare i32 @llvm.ctpop.i32(i32)
declare i64 @llvm.ctpop.i64(i64)
-declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)
\ No newline at end of file
+declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)
>From 3cd3a0fc063e4ae19d43261deaf6d1f3789e0217 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 24 Oct 2024 00:34:22 +0000
Subject: [PATCH 4/4] address latest pr comments + extra tests
---
.../test/CodeGenHLSL/builtins/countbits.hlsl | 20 +++++++++++++++++++
llvm/lib/Target/DirectX/DXILOpLowering.cpp | 19 ++++++++++--------
llvm/test/CodeGen/DirectX/countbits.ll | 9 +++++++++
3 files changed, 40 insertions(+), 8 deletions(-)
diff --git a/clang/test/CodeGenHLSL/builtins/countbits.hlsl b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
index aa9ef40d7a0dc8..218d8dcd10f8d7 100644
--- a/clang/test/CodeGenHLSL/builtins/countbits.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
@@ -10,6 +10,13 @@ uint test_countbits_ushort(uint16_t p0)
{
return countbits(p0);
}
+// CHECK-LABEL: test_countbits_short
+// CHECK: [[A:%.*]] = call i16 @llvm.ctpop.i16
+// CHECK-NEXT: sext i16 [[A]] to i32
+uint test_countbits_short(int16_t p0)
+{
+ return countbits(p0);
+}
// CHECK-LABEL: test_countbits_ushort2
// CHECK: [[A:%.*]] = call <2 x i16> @llvm.ctpop.v2i16
// CHECK-NEXT: zext <2 x i16> [[A]] to <2 x i32>
@@ -39,6 +46,12 @@ uint test_countbits_uint(uint p0)
{
return countbits(p0);
}
+// CHECK-LABEL: test_countbits_int
+// CHECK: call i32 @llvm.ctpop.i32
+uint test_countbits_int(int p0)
+{
+ return countbits(p0);
+}
// CHECK-LABEL: test_countbits_uint2
// CHECK: call <2 x i32> @llvm.ctpop.v2i32
uint2 test_countbits_uint2(uint2 p0)
@@ -65,6 +78,13 @@ uint test_countbits_long(uint64_t p0)
{
return countbits(p0);
}
+// CHECK-LABEL: test_countbits_slong
+// CHECK: [[A:%.*]] = call i64 @llvm.ctpop.i64
+// CHECK-NEXT: trunc i64 [[A]] to i32
+uint test_countbits_slong(int64_t p0)
+{
+ return countbits(p0);
+}
// CHECK-LABEL: test_countbits_long2
// CHECK: [[A:%.*]] = call <2 x i64> @llvm.ctpop.v2i64
// CHECK-NEXT: trunc <2 x i64> [[A]] to <2 x i32>
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 629ecd35b81c21..98bcbf2a363ce0 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -487,33 +487,39 @@ class OpLowerer {
}
unsigned CastOp;
- if (FRT->isIntOrIntVectorTy(16))
+ // ctpop never yields a value larger than its operand's bit width, so an i16
+ // result has its sign bit clear and sext/zext of it to i32 are equivalent.
+ unsigned CastOp2;
+ if (FRT->isIntOrIntVectorTy(16)) {
CastOp = Instruction::ZExt;
- else { // must be 64 bits
- assert(FRT->isIntOrIntVectorTy(64) &&
- "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
- is supported.");
+ CastOp2 = Instruction::SExt;
+ } else { // must be 64 bits
+ assert(FRT->isIntOrIntVectorTy(64) &&
+ "Currently only lowering 16, 32, or 64 bit ctpop to CountBits "
+ "is supported.");
CastOp = Instruction::Trunc;
+ CastOp2 = Instruction::Trunc;
}
- // It is correct to replace the ctpop with the dxil op and
- // remove all casts to i32
- bool nonCastInstr = false;
- for (User *User : make_early_inc_range(CI->users())) {
- Instruction *I;
- if ((I = dyn_cast<Instruction>(User)) != NULL &&
- I->getOpcode() == CastOp && I->getType() == RetTy) {
+ // It is correct to replace the ctpop with the dxil op and remove every
+ // cast of its result to i32 (zext/sext for i16, trunc for i64).
+ bool NeedsCast = false;
+ for (User *U : make_early_inc_range(CI->users())) {
+ Instruction *I = dyn_cast<Instruction>(U);
+ if (I && (I->getOpcode() == CastOp || I->getOpcode() == CastOp2) &&
+ I->getType() == RetTy) {
I->replaceAllUsesWith(*OpCall);
I->eraseFromParent();
- } else
- nonCastInstr = true;
+ } else {
+ NeedsCast = true;
+ }
}
// It is correct to replace a ctpop with the dxil op and
// a cast from i32 to the return type of the ctpop
// the cast is emitted here if there is a non-cast to i32
// instr which uses the ctpop
- if (nonCastInstr) {
+ if (NeedsCast) {
Value *Cast =
IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast");
CI->replaceAllUsesWith(Cast);
diff --git a/llvm/test/CodeGen/DirectX/countbits.ll b/llvm/test/CodeGen/DirectX/countbits.ll
index 11675d67ed0cba..f03ab9c5e79c35 100644
--- a/llvm/test/CodeGen/DirectX/countbits.ll
+++ b/llvm/test/CodeGen/DirectX/countbits.ll
@@ -20,6 +20,15 @@ entry:
ret i32 %elt.zext
}
+define noundef i32 @test_countbits_short3(i16 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i16(i32 31, i16 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
+ %elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a)
+ %elt.sext = sext i16 %elt.ctpop to i32
+ ret i32 %elt.sext
+}
+
define noundef i32 @test_countbits_int(i32 noundef %a) {
entry:
; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 %{{.*}})
More information about the cfe-commits
mailing list