[clang] [llvm] [HLSL] Re-implement countbits with the correct return type (PR #113189)

Sarah Spall via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 28 16:39:24 PDT 2024


https://github.com/spall updated https://github.com/llvm/llvm-project/pull/113189

>From 23d62026c8338e6ad92495cfcaa54ff1fa5d08f0 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Wed, 16 Oct 2024 19:00:08 +0000
Subject: [PATCH 1/5] implement countbits correctly

---
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      | 126 +++++++++++-------
 .../test/CodeGenHLSL/builtins/countbits.hlsl  |  42 +++---
 .../SemaHLSL/BuiltIns/countbits-errors.hlsl   |  14 +-
 llvm/lib/Target/DirectX/DXIL.td               |   5 +-
 llvm/lib/Target/DirectX/DXILOpLowering.cpp    |  64 +++++++++
 llvm/test/CodeGen/DirectX/countbits.ll        |  39 ++++--
 6 files changed, 202 insertions(+), 88 deletions(-)

diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 8ade4b27f360fb..7936506a0461b6 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -723,66 +723,90 @@ float4 cosh(float4);
 
 #ifdef __HLSL_ENABLE_16_BIT
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t countbits(int16_t);
+constexpr uint countbits(int16_t x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t2 countbits(int16_t2);
+constexpr uint2 countbits(int16_t2 x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t3 countbits(int16_t3);
+constexpr uint3 countbits(int16_t3 x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int16_t4 countbits(int16_t4);
+constexpr uint4 countbits(int16_t4 x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t countbits(uint16_t);
+constexpr uint countbits(uint16_t x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t2 countbits(uint16_t2);
+constexpr uint2 countbits(uint16_t2 x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t3 countbits(uint16_t3);
+constexpr uint3 countbits(uint16_t3 x) {
+  return __builtin_elementwise_popcount(x);
+}
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint16_t4 countbits(uint16_t4);
+constexpr uint4 countbits(uint16_t4 x) {
+  return __builtin_elementwise_popcount(x);
+}
 #endif
 
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int countbits(int);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int2 countbits(int2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int3 countbits(int3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int4 countbits(int4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint countbits(uint);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint2 countbits(uint2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint3 countbits(uint3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint4 countbits(uint4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t countbits(int64_t);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t2 countbits(int64_t2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t3 countbits(int64_t3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-int64_t4 countbits(int64_t4);
-
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t countbits(uint64_t);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t2 countbits(uint64_t2);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t3 countbits(uint64_t3);
-_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
-uint64_t4 countbits(uint64_t4);
+constexpr uint countbits(int x) {
+  return __builtin_elementwise_popcount(x);
+}  
+constexpr uint2 countbits(int2 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(int3 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(int4 x) {
+  return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(uint x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(uint2 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(uint3 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(uint4 x) {
+  return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(int64_t x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(int64_t2 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(int64_t3 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(int64_t4 x) {
+  return __builtin_elementwise_popcount(x);
+}
+
+constexpr uint countbits(uint64_t x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint2 countbits(uint64_t2 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint3 countbits(uint64_t3 x) {
+  return __builtin_elementwise_popcount(x);
+}
+constexpr uint4 countbits(uint64_t4 x) {
+  return __builtin_elementwise_popcount(x);
+}
 
 //===----------------------------------------------------------------------===//
 // degrees builtins
diff --git a/clang/test/CodeGenHLSL/builtins/countbits.hlsl b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
index 8dfe977bfae626..aa9ef40d7a0dc8 100644
--- a/clang/test/CodeGenHLSL/builtins/countbits.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
@@ -4,26 +4,30 @@
 
 #ifdef __HLSL_ENABLE_16_BIT
 // CHECK-LABEL: test_countbits_ushort
-// CHECK: call i16 @llvm.ctpop.i16
-uint16_t test_countbits_ushort(uint16_t p0)
+// CHECK: [[A:%.*]] = call i16 @llvm.ctpop.i16
+// CHECK-NEXT: zext i16 [[A]] to i32
+uint test_countbits_ushort(uint16_t p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_ushort2
-// CHECK: call <2 x i16> @llvm.ctpop.v2i16
-uint16_t2 test_countbits_ushort2(uint16_t2 p0)
+// CHECK: [[A:%.*]] = call <2 x i16> @llvm.ctpop.v2i16
+// CHECK-NEXT: zext <2 x i16> [[A]] to <2 x i32>
+uint2 test_countbits_ushort2(uint16_t2 p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_ushort3
-// CHECK: call <3 x i16> @llvm.ctpop.v3i16
-uint16_t3 test_countbits_ushort3(uint16_t3 p0)
+// CHECK: [[A:%.*]] = call <3 x i16> @llvm.ctpop.v3i16
+// CHECK-NEXT: zext <3 x i16> [[A]] to <3 x i32>
+uint3 test_countbits_ushort3(uint16_t3 p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_ushort4
-// CHECK: call <4 x i16> @llvm.ctpop.v4i16
-uint16_t4 test_countbits_ushort4(uint16_t4 p0)
+// CHECK: [[A:%.*]] = call <4 x i16> @llvm.ctpop.v4i16
+// CHECK-NEXT: zext <4 x i16> [[A]] to <4 x i32>
+uint4 test_countbits_ushort4(uint16_t4 p0)
 {
 	return countbits(p0);
 }
@@ -31,7 +35,7 @@ uint16_t4 test_countbits_ushort4(uint16_t4 p0)
 
 // CHECK-LABEL: test_countbits_uint
 // CHECK: call i32 @llvm.ctpop.i32
-int test_countbits_uint(uint p0)
+uint test_countbits_uint(uint p0)
 {
 	return countbits(p0);
 }
@@ -55,26 +59,30 @@ uint4 test_countbits_uint4(uint4 p0)
 }
 
 // CHECK-LABEL: test_countbits_long
-// CHECK: call i64 @llvm.ctpop.i64
-uint64_t test_countbits_long(uint64_t p0)
+// CHECK: [[A:%.*]] = call i64 @llvm.ctpop.i64
+// CHECK-NEXT: trunc i64 [[A]] to i32
+uint test_countbits_long(uint64_t p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_long2
-// CHECK: call <2 x i64> @llvm.ctpop.v2i64
-uint64_t2 test_countbits_long2(uint64_t2 p0)
+// CHECK: [[A:%.*]] = call <2 x i64> @llvm.ctpop.v2i64
+// CHECK-NEXT: trunc <2 x i64> [[A]] to <2 x i32>
+uint2 test_countbits_long2(uint64_t2 p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_long3
-// CHECK: call <3 x i64> @llvm.ctpop.v3i64
-uint64_t3 test_countbits_long3(uint64_t3 p0)
+// CHECK: [[A:%.*]] = call <3 x i64> @llvm.ctpop.v3i64
+// CHECK-NEXT: trunc <3 x i64> [[A]] to <3 x i32>
+uint3 test_countbits_long3(uint64_t3 p0)
 {
 	return countbits(p0);
 }
 // CHECK-LABEL: test_countbits_long4
-// CHECK: call <4 x i64> @llvm.ctpop.v4i64
-uint64_t4 test_countbits_long4(uint64_t4 p0)
+// CHECK: [[A:%.*]] = call <4 x i64> @llvm.ctpop.v4i64
+// CHECK-NEXT: trunc <4 x i64> [[A]] to <4 x i32>
+uint4 test_countbits_long4(uint64_t4 p0)
 {
 	return countbits(p0);
 }
diff --git a/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
index 8d5f0abb2860f8..5704165e1a4505 100644
--- a/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
@@ -1,6 +1,4 @@
-// RUN: %clang_cc1 -finclude-default-header
-// -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only
-// -disable-llvm-passes -verify -verify-ignore-unexpected
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected
 
 
 double test_int_builtin(double p0) {
@@ -9,13 +7,11 @@ double test_int_builtin(double p0) {
 }
 
 double2 test_int_builtin_2(double2 p0) {
-  return __builtin_elementwise_popcount(p0);
-  // expected-error at -1 {{1st argument must be a vector of integers
-  // (was 'double2' (aka 'vector<double, 2>'))}}
+  return countbits(p0);
+  // expected-error at -1 {{call to 'countbits' is ambiguous}}
 }
 
 double test_int_builtin_3(float p0) {
-  return __builtin_elementwise_popcount(p0);
-  // expected-error at -1 {{1st argument must be a vector of integers
-  // (was 'float')}}
+  return countbits(p0);
+  // expected-error at -1 {{call to 'countbits' is ambiguous}}
 }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 68ae5de06423c2..1e51a04f4997b1 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -554,11 +554,10 @@ def Rbits :  DXILOp<30, unary> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
-def CBits :  DXILOp<31, unary> {
+def CBits :  DXILOp<31, unaryBits> {
   let Doc = "Returns the number of 1 bits in the specified value.";
-  let LLVMIntrinsic = int_ctpop;
   let arguments = [OverloadTy];
-  let result = OverloadTy;
+  let result = Int32Ty;
   let overloads =
       [Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
   let stages = [Stages<DXIL1_0, [all_stages]>];
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index f7722d77074764..646bced9460649 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -505,6 +505,67 @@ class OpLowerer {
     });
   }
 
+  [[nodiscard]] bool lowerCtpopToCBits(Function &F) {
+    IRBuilder<> &IRB = OpBuilder.getIRB();
+    Type *Int32Ty = IRB.getInt32Ty();
+    
+    return replaceFunction(F, [&](CallInst *CI) -> Error {
+      IRB.SetInsertPoint(CI);
+      SmallVector<Value *> Args;
+      Args.append(CI->arg_begin(), CI->arg_end());
+
+      Type *RetTy = Int32Ty;
+      Type *FRT = F.getReturnType();
+      if (FRT->isVectorTy()) {
+        VectorType *VT = cast<VectorType>(FRT);
+	RetTy = VectorType::get(RetTy, VT);
+      }
+      
+      Expected<CallInst *> OpCall =
+	OpBuilder.tryCreateOp(dxil::OpCode::CBits, Args, CI->getName(), RetTy);
+      if (Error E = OpCall.takeError())
+	return E;
+
+      // If the result type is 32 bits we can do a direct replacement.
+      if (FRT->isIntOrIntVectorTy(32)) {
+        CI->replaceAllUsesWith(*OpCall);
+	CI->eraseFromParent();
+	return Error::success();
+      }
+
+      unsigned CastOp;
+      if (FRT->isIntOrIntVectorTy(16))
+	CastOp = Instruction::ZExt;
+      else // must be 64 bits
+	CastOp = Instruction::Trunc;
+
+      // It is correct to replace the ctpop with the dxil op and
+      // remove an existing cast iff the cast is the only usage of
+      // the ctpop
+      // can use hasOneUse instead of hasOneUser, because the user
+      // we care about should have one operand
+      if (CI->hasOneUse()) {
+	User *U = CI->user_back();
+	Instruction *I;
+	if (isa<Instruction>(U) && (I = cast<Instruction>(U)) &&
+	    I->getOpcode() == CastOp && I->getType() == RetTy) {
+          I->replaceAllUsesWith(*OpCall);
+	  I->eraseFromParent();
+	  CI->eraseFromParent();
+	  return Error::success();
+	  }
+      }
+
+      // It is always correct to replace a ctpop with the dxil op and
+      // a cast
+      Value *Cast = IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(),
+					  "ctpop.cast");
+      CI->replaceAllUsesWith(Cast);
+      CI->eraseFromParent();
+      return Error::success();
+    });
+  }
+
   bool lowerIntrinsics() {
     bool Updated = false;
     bool HasErrors = false;
@@ -543,6 +604,9 @@ class OpLowerer {
               return replaceSplitDoubleCallUsages(CI, Op);
             });
         break;
+      case Intrinsic::ctpop:
+	HasErrors |= lowerCtpopToCBits(F);
+	break;
       }
       Updated = true;
     }
diff --git a/llvm/test/CodeGen/DirectX/countbits.ll b/llvm/test/CodeGen/DirectX/countbits.ll
index c6bc2b6790948e..91f6f560903f01 100644
--- a/llvm/test/CodeGen/DirectX/countbits.ll
+++ b/llvm/test/CodeGen/DirectX/countbits.ll
@@ -4,35 +4,58 @@
 
 define noundef i16 @test_countbits_short(i16 noundef %a) {
 entry:
-; CHECK: call i16 @dx.op.unary.i16(i32 31, i16 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i16(i32 31, i16 %{{.*}})
+; CHECK-NEXT: [[B:%.*]] = trunc i32 [[A]] to i16
+; CHECK-NEXT: ret i16 [[B]]
   %elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a)
   ret i16 %elt.ctpop
 }
 
+define noundef i32 @test_countbits_short2(i16 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i16(i32 31, i16 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
+  %elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a)
+  %elt.zext = zext i16 %elt.ctpop to i32
+  ret i32 %elt.zext
+}
+
 define noundef i32 @test_countbits_int(i32 noundef %a) {
 entry:
-; CHECK: call i32 @dx.op.unary.i32(i32 31, i32 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
   %elt.ctpop = call i32 @llvm.ctpop.i32(i32 %a)
   ret i32 %elt.ctpop
 }
 
 define noundef i64 @test_countbits_long(i64 noundef %a) {
 entry:
-; CHECK: call i64 @dx.op.unary.i64(i32 31, i64 %{{.*}})
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i64(i32 31, i64 %{{.*}})
+; CHECK-NEXT: [[B:%.*]] = zext i32 [[A]] to i64
+; CHECK-NEXT: ret i64 [[B]]
   %elt.ctpop = call i64 @llvm.ctpop.i64(i64 %a)
   ret i64 %elt.ctpop
 }
 
+define noundef i32 @test_countbits_long2(i64 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i64(i32 31, i64 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
+  %elt.ctpop = call i64 @llvm.ctpop.i64(i64 %a)
+  %elt.trunc = trunc i64 %elt.ctpop to i32
+  ret i32 %elt.trunc
+}
+
 define noundef <4 x i32> @countbits_vec4_i32(<4 x i32> noundef %a)  {
 entry:
   ; CHECK: [[ee0:%.*]] = extractelement <4 x i32> %a, i64 0
-  ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee0]])
+  ; CHECK: [[ie0:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee0]])
   ; CHECK: [[ee1:%.*]] = extractelement <4 x i32> %a, i64 1
-  ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee1]])
+  ; CHECK: [[ie1:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee1]])
   ; CHECK: [[ee2:%.*]] = extractelement <4 x i32> %a, i64 2
-  ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee2]])
+  ; CHECK: [[ie2:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee2]])
   ; CHECK: [[ee3:%.*]] = extractelement <4 x i32> %a, i64 3
-  ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unary.i32(i32 31, i32 [[ee3]])
+  ; CHECK: [[ie3:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 [[ee3]])
   ; CHECK: insertelement <4 x i32> poison, i32 [[ie0]], i64 0
   ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie1]], i64 1
   ; CHECK: insertelement <4 x i32> %{{.*}}, i32 [[ie2]], i64 2
@@ -44,4 +67,4 @@ entry:
 declare i16 @llvm.ctpop.i16(i16)
 declare i32 @llvm.ctpop.i32(i32)
 declare i64 @llvm.ctpop.i64(i64)
-declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)
+declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)
\ No newline at end of file

>From e29f40164631ebc6cd3bd5a8168a6c4e13a14a28 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Mon, 21 Oct 2024 16:31:07 +0000
Subject: [PATCH 2/5] make clang format happy

---
 clang/lib/Headers/hlsl/hlsl_intrinsics.h   | 34 +++++------------
 llvm/lib/Target/DirectX/DXILOpLowering.cpp | 44 +++++++++++-----------
 2 files changed, 31 insertions(+), 47 deletions(-)

diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 7936506a0461b6..a3d2be5c41d109 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -756,31 +756,15 @@ constexpr uint4 countbits(uint16_t4 x) {
 }
 #endif
 
-constexpr uint countbits(int x) {
-  return __builtin_elementwise_popcount(x);
-}  
-constexpr uint2 countbits(int2 x) {
-  return __builtin_elementwise_popcount(x);
-}
-constexpr uint3 countbits(int3 x) {
-  return __builtin_elementwise_popcount(x);
-}
-constexpr uint4 countbits(int4 x) {
-  return __builtin_elementwise_popcount(x);
-}
-
-constexpr uint countbits(uint x) {
-  return __builtin_elementwise_popcount(x);
-}
-constexpr uint2 countbits(uint2 x) {
-  return __builtin_elementwise_popcount(x);
-}
-constexpr uint3 countbits(uint3 x) {
-  return __builtin_elementwise_popcount(x);
-}
-constexpr uint4 countbits(uint4 x) {
-  return __builtin_elementwise_popcount(x);
-}
+constexpr uint countbits(int x) { return __builtin_elementwise_popcount(x); }
+constexpr uint2 countbits(int2 x) { return __builtin_elementwise_popcount(x); }
+constexpr uint3 countbits(int3 x) { return __builtin_elementwise_popcount(x); }
+constexpr uint4 countbits(int4 x) { return __builtin_elementwise_popcount(x); }
+
+constexpr uint countbits(uint x) { return __builtin_elementwise_popcount(x); }
+constexpr uint2 countbits(uint2 x) { return __builtin_elementwise_popcount(x); }
+constexpr uint3 countbits(uint3 x) { return __builtin_elementwise_popcount(x); }
+constexpr uint4 countbits(uint4 x) { return __builtin_elementwise_popcount(x); }
 
 constexpr uint countbits(int64_t x) {
   return __builtin_elementwise_popcount(x);
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 646bced9460649..beb446738444fe 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -508,7 +508,7 @@ class OpLowerer {
   [[nodiscard]] bool lowerCtpopToCBits(Function &F) {
     IRBuilder<> &IRB = OpBuilder.getIRB();
     Type *Int32Ty = IRB.getInt32Ty();
-    
+
     return replaceFunction(F, [&](CallInst *CI) -> Error {
       IRB.SetInsertPoint(CI);
       SmallVector<Value *> Args;
@@ -518,26 +518,26 @@ class OpLowerer {
       Type *FRT = F.getReturnType();
       if (FRT->isVectorTy()) {
         VectorType *VT = cast<VectorType>(FRT);
-	RetTy = VectorType::get(RetTy, VT);
+        RetTy = VectorType::get(RetTy, VT);
       }
-      
-      Expected<CallInst *> OpCall =
-	OpBuilder.tryCreateOp(dxil::OpCode::CBits, Args, CI->getName(), RetTy);
+
+      Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
+          dxil::OpCode::CBits, Args, CI->getName(), RetTy);
       if (Error E = OpCall.takeError())
-	return E;
+        return E;
 
       // If the result type is 32 bits we can do a direct replacement.
       if (FRT->isIntOrIntVectorTy(32)) {
         CI->replaceAllUsesWith(*OpCall);
-	CI->eraseFromParent();
-	return Error::success();
+        CI->eraseFromParent();
+        return Error::success();
       }
 
       unsigned CastOp;
       if (FRT->isIntOrIntVectorTy(16))
-	CastOp = Instruction::ZExt;
+        CastOp = Instruction::ZExt;
       else // must be 64 bits
-	CastOp = Instruction::Trunc;
+        CastOp = Instruction::Trunc;
 
       // It is correct to replace the ctpop with the dxil op and
       // remove an existing cast iff the cast is the only usage of
@@ -545,21 +545,21 @@ class OpLowerer {
       // can use hasOneUse instead of hasOneUser, because the user
       // we care about should have one operand
       if (CI->hasOneUse()) {
-	User *U = CI->user_back();
-	Instruction *I;
-	if (isa<Instruction>(U) && (I = cast<Instruction>(U)) &&
-	    I->getOpcode() == CastOp && I->getType() == RetTy) {
+        User *U = CI->user_back();
+        Instruction *I;
+        if (isa<Instruction>(U) && (I = cast<Instruction>(U)) &&
+            I->getOpcode() == CastOp && I->getType() == RetTy) {
           I->replaceAllUsesWith(*OpCall);
-	  I->eraseFromParent();
-	  CI->eraseFromParent();
-	  return Error::success();
-	  }
+          I->eraseFromParent();
+          CI->eraseFromParent();
+          return Error::success();
+        }
       }
 
       // It is always correct to replace a ctpop with the dxil op and
       // a cast
-      Value *Cast = IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(),
-					  "ctpop.cast");
+      Value *Cast =
+          IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast");
       CI->replaceAllUsesWith(Cast);
       CI->eraseFromParent();
       return Error::success();
@@ -605,8 +605,8 @@ class OpLowerer {
             });
         break;
       case Intrinsic::ctpop:
-	HasErrors |= lowerCtpopToCBits(F);
-	break;
+        HasErrors |= lowerCtpopToCBits(F);
+        break;
       }
       Updated = true;
     }

>From 5e403503acaabe8f20fbbf790192ce1d904bc42d Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Wed, 23 Oct 2024 16:31:04 +0000
Subject: [PATCH 3/5] address PR comments

---
 llvm/lib/Target/DirectX/DXIL.td            |  2 +-
 llvm/lib/Target/DirectX/DXILOpLowering.cpp | 47 ++++++++++++----------
 llvm/test/CodeGen/DirectX/countbits.ll     |  2 +-
 3 files changed, 27 insertions(+), 24 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 1e51a04f4997b1..1e8dc63ffa257e 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -554,7 +554,7 @@ def Rbits :  DXILOp<30, unary> {
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
 
-def CBits :  DXILOp<31, unaryBits> {
+def CountBits :  DXILOp<31, unaryBits> {
   let Doc = "Returns the number of 1 bits in the specified value.";
   let arguments = [OverloadTy];
   let result = Int32Ty;
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index beb446738444fe..7d9da4b8ab76c9 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -505,7 +505,7 @@ class OpLowerer {
     });
   }
 
-  [[nodiscard]] bool lowerCtpopToCBits(Function &F) {
+  [[nodiscard]] bool lowerCtpopToCountBits(Function &F) {
     IRBuilder<> &IRB = OpBuilder.getIRB();
     Type *Int32Ty = IRB.getInt32Ty();
 
@@ -516,13 +516,11 @@ class OpLowerer {
 
       Type *RetTy = Int32Ty;
       Type *FRT = F.getReturnType();
-      if (FRT->isVectorTy()) {
-        VectorType *VT = cast<VectorType>(FRT);
+      if (const auto *VT = dyn_cast<VectorType>(FRT))
         RetTy = VectorType::get(RetTy, VT);
-      }
 
       Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
-          dxil::OpCode::CBits, Args, CI->getName(), RetTy);
+          dxil::OpCode::CountBits, Args, CI->getName(), RetTy);
       if (Error E = OpCall.takeError())
         return E;
 
@@ -536,31 +534,36 @@ class OpLowerer {
       unsigned CastOp;
       if (FRT->isIntOrIntVectorTy(16))
         CastOp = Instruction::ZExt;
-      else // must be 64 bits
+      else { // must be 64 bits
+        assert(FRT->isIntOrIntVectorTy(64) &&
+               "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
+                is supported.");
         CastOp = Instruction::Trunc;
+      }
 
       // It is correct to replace the ctpop with the dxil op and
-      // remove an existing cast iff the cast is the only usage of
-      // the ctpop
-      // can use hasOneUse instead of hasOneUser, because the user
-      // we care about should have one operand
-      if (CI->hasOneUse()) {
-        User *U = CI->user_back();
+      // remove all casts to i32
+      bool nonCastInstr = false;
+      for (User *User : make_early_inc_range(CI->users())) {
         Instruction *I;
-        if (isa<Instruction>(U) && (I = cast<Instruction>(U)) &&
+        if ((I = dyn_cast<Instruction>(User)) != NULL &&
             I->getOpcode() == CastOp && I->getType() == RetTy) {
           I->replaceAllUsesWith(*OpCall);
           I->eraseFromParent();
-          CI->eraseFromParent();
-          return Error::success();
-        }
+        } else
+          nonCastInstr = true;
+      }
+
+      // It is correct to replace a ctpop with the dxil op and
+      // a cast from i32 to the return type of the ctpop
+      // the cast is emitted here if there is a non-cast to i32
+      // instr which uses the ctpop
+      if (nonCastInstr) {
+        Value *Cast =
+            IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast");
+        CI->replaceAllUsesWith(Cast);
       }
 
-      // It is always correct to replace a ctpop with the dxil op and
-      // a cast
-      Value *Cast =
-          IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast");
-      CI->replaceAllUsesWith(Cast);
       CI->eraseFromParent();
       return Error::success();
     });
@@ -605,7 +608,7 @@ class OpLowerer {
             });
         break;
       case Intrinsic::ctpop:
-        HasErrors |= lowerCtpopToCBits(F);
+        HasErrors |= lowerCtpopToCountBits(F);
         break;
       }
       Updated = true;
diff --git a/llvm/test/CodeGen/DirectX/countbits.ll b/llvm/test/CodeGen/DirectX/countbits.ll
index 91f6f560903f01..11675d67ed0cba 100644
--- a/llvm/test/CodeGen/DirectX/countbits.ll
+++ b/llvm/test/CodeGen/DirectX/countbits.ll
@@ -67,4 +67,4 @@ entry:
 declare i16 @llvm.ctpop.i16(i16)
 declare i32 @llvm.ctpop.i32(i32)
 declare i64 @llvm.ctpop.i64(i64)
-declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)
\ No newline at end of file
+declare <4 x i32> @llvm.ctpop.v4i32(<4 x i32>)

>From 74f09256ed6441b06049e1e1fc5f20c80c761a07 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 24 Oct 2024 00:34:22 +0000
Subject: [PATCH 4/5] address latest PR comments and add extra tests

---
 .../test/CodeGenHLSL/builtins/countbits.hlsl  | 20 +++++++++++++++++++
 llvm/lib/Target/DirectX/DXILOpLowering.cpp    | 19 ++++++++++--------
 llvm/test/CodeGen/DirectX/countbits.ll        |  9 +++++++++
 3 files changed, 40 insertions(+), 8 deletions(-)

diff --git a/clang/test/CodeGenHLSL/builtins/countbits.hlsl b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
index aa9ef40d7a0dc8..218d8dcd10f8d7 100644
--- a/clang/test/CodeGenHLSL/builtins/countbits.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/countbits.hlsl
@@ -10,6 +10,13 @@ uint test_countbits_ushort(uint16_t p0)
 {
 	return countbits(p0);
 }
+// CHECK-LABEL: test_countbits_short
+// CHECK: [[A:%.*]] = call i16 @llvm.ctpop.i16
+// CHECK-NEXT: sext i16 [[A]] to i32
+uint test_countbits_short(int16_t p0)
+{
+	return countbits(p0);
+}
 // CHECK-LABEL: test_countbits_ushort2
 // CHECK: [[A:%.*]] = call <2 x i16> @llvm.ctpop.v2i16
 // CHECK-NEXT: zext <2 x i16> [[A]] to <2 x i32>
@@ -39,6 +46,12 @@ uint test_countbits_uint(uint p0)
 {
 	return countbits(p0);
 }
+// CHECK-LABEL: test_countbits_int
+// CHECK: call i32 @llvm.ctpop.i32
+uint test_countbits_int(int p0)
+{
+	return countbits(p0);
+}
 // CHECK-LABEL: test_countbits_uint2
 // CHECK: call <2 x i32> @llvm.ctpop.v2i32
 uint2 test_countbits_uint2(uint2 p0)
@@ -65,6 +78,13 @@ uint test_countbits_long(uint64_t p0)
 {
 	return countbits(p0);
 }
+// CHECK-LABEL: test_countbits_slong
+// CHECK: [[A:%.*]] = call i64 @llvm.ctpop.i64
+// CHECK-NEXT: trunc i64 [[A]] to i32
+uint test_countbits_slong(int64_t p0)
+{
+	return countbits(p0);
+}
 // CHECK-LABEL: test_countbits_long2
 // CHECK: [[A:%.*]] = call <2 x i64> @llvm.ctpop.v2i64
 // CHECK-NEXT: trunc <2 x i64> [[A]] to <2 x i32>
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 7d9da4b8ab76c9..8acc9c1efa08c0 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -532,33 +532,36 @@ class OpLowerer {
       }
 
       unsigned CastOp;
-      if (FRT->isIntOrIntVectorTy(16))
+      unsigned CastOp2;
+      if (FRT->isIntOrIntVectorTy(16)) {
         CastOp = Instruction::ZExt;
-      else { // must be 64 bits
+        CastOp2 = Instruction::SExt;
+      } else { // must be 64 bits
         assert(FRT->isIntOrIntVectorTy(64) &&
                "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
                 is supported.");
         CastOp = Instruction::Trunc;
+        CastOp2 = Instruction::Trunc;
       }
 
       // It is correct to replace the ctpop with the dxil op and
       // remove all casts to i32
-      bool nonCastInstr = false;
+      bool NeedsCast = false;
       for (User *User : make_early_inc_range(CI->users())) {
-        Instruction *I;
-        if ((I = dyn_cast<Instruction>(User)) != NULL &&
-            I->getOpcode() == CastOp && I->getType() == RetTy) {
+        Instruction *I = dyn_cast<Instruction>(User);
+        if (I && (I->getOpcode() == CastOp || I->getOpcode() == CastOp2) &&
+            I->getType() == RetTy) {
           I->replaceAllUsesWith(*OpCall);
           I->eraseFromParent();
         } else
-          nonCastInstr = true;
+          NeedsCast = true;
       }
 
       // It is correct to replace a ctpop with the dxil op and
       // a cast from i32 to the return type of the ctpop
       // the cast is emitted here if there is a non-cast to i32
       // instr which uses the ctpop
-      if (nonCastInstr) {
+      if (NeedsCast) {
         Value *Cast =
             IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast");
         CI->replaceAllUsesWith(Cast);
diff --git a/llvm/test/CodeGen/DirectX/countbits.ll b/llvm/test/CodeGen/DirectX/countbits.ll
index 11675d67ed0cba..f03ab9c5e79c35 100644
--- a/llvm/test/CodeGen/DirectX/countbits.ll
+++ b/llvm/test/CodeGen/DirectX/countbits.ll
@@ -20,6 +20,15 @@ entry:
   ret i32 %elt.zext
 }
 
+define noundef i32 @test_countbits_short3(i16 noundef %a) {
+entry:
+; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i16(i32 31, i16 %{{.*}})
+; CHECK-NEXT: ret i32 [[A]]
+  %elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a)
+  %elt.sext = sext i16 %elt.ctpop to i32
+  ret i32 %elt.sext
+}
+
 define noundef i32 @test_countbits_int(i32 noundef %a) {
 entry:
 ; CHECK: [[A:%.*]] = call i32 @dx.op.unaryBits.i32(i32 31, i32 %{{.*}})

>From d160782624de89df5554f04e3b411da173380b42 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Mon, 28 Oct 2024 22:02:21 +0000
Subject: [PATCH 5/5] remove use of constexpr and replace with const inline

---
 clang/lib/Headers/hlsl/hlsl_intrinsics.h | 62 +++++++++++++++---------
 1 file changed, 38 insertions(+), 24 deletions(-)

diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index a3d2be5c41d109..d9f3a17ea23d8e 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -723,72 +723,86 @@ float4 cosh(float4);
 
 #ifdef __HLSL_ENABLE_16_BIT
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-constexpr uint countbits(int16_t x) {
+const inline uint countbits(int16_t x) {
   return __builtin_elementwise_popcount(x);
 }
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-constexpr uint2 countbits(int16_t2 x) {
+const inline uint2 countbits(int16_t2 x) {
   return __builtin_elementwise_popcount(x);
 }
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-constexpr uint3 countbits(int16_t3 x) {
+const inline uint3 countbits(int16_t3 x) {
   return __builtin_elementwise_popcount(x);
 }
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-constexpr uint4 countbits(int16_t4 x) {
+const inline uint4 countbits(int16_t4 x) {
   return __builtin_elementwise_popcount(x);
 }
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-constexpr uint countbits(uint16_t x) {
+const inline uint countbits(uint16_t x) {
   return __builtin_elementwise_popcount(x);
 }
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-constexpr uint2 countbits(uint16_t2 x) {
+const inline uint2 countbits(uint16_t2 x) {
   return __builtin_elementwise_popcount(x);
 }
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-constexpr uint3 countbits(uint16_t3 x) {
+const inline uint3 countbits(uint16_t3 x) {
   return __builtin_elementwise_popcount(x);
 }
 _HLSL_AVAILABILITY(shadermodel, 6.2)
-constexpr uint4 countbits(uint16_t4 x) {
+const inline uint4 countbits(uint16_t4 x) {
   return __builtin_elementwise_popcount(x);
 }
 #endif
 
-constexpr uint countbits(int x) { return __builtin_elementwise_popcount(x); }
-constexpr uint2 countbits(int2 x) { return __builtin_elementwise_popcount(x); }
-constexpr uint3 countbits(int3 x) { return __builtin_elementwise_popcount(x); }
-constexpr uint4 countbits(int4 x) { return __builtin_elementwise_popcount(x); }
+const inline uint countbits(int x) { return __builtin_elementwise_popcount(x); }
+const inline uint2 countbits(int2 x) {
+  return __builtin_elementwise_popcount(x);
+}
+const inline uint3 countbits(int3 x) {
+  return __builtin_elementwise_popcount(x);
+}
+const inline uint4 countbits(int4 x) {
+  return __builtin_elementwise_popcount(x);
+}
 
-constexpr uint countbits(uint x) { return __builtin_elementwise_popcount(x); }
-constexpr uint2 countbits(uint2 x) { return __builtin_elementwise_popcount(x); }
-constexpr uint3 countbits(uint3 x) { return __builtin_elementwise_popcount(x); }
-constexpr uint4 countbits(uint4 x) { return __builtin_elementwise_popcount(x); }
+const inline uint countbits(uint x) {
+  return __builtin_elementwise_popcount(x);
+}
+const inline uint2 countbits(uint2 x) {
+  return __builtin_elementwise_popcount(x);
+}
+const inline uint3 countbits(uint3 x) {
+  return __builtin_elementwise_popcount(x);
+}
+const inline uint4 countbits(uint4 x) {
+  return __builtin_elementwise_popcount(x);
+}
 
-constexpr uint countbits(int64_t x) {
+const inline uint countbits(int64_t x) {
   return __builtin_elementwise_popcount(x);
 }
-constexpr uint2 countbits(int64_t2 x) {
+const inline uint2 countbits(int64_t2 x) {
   return __builtin_elementwise_popcount(x);
 }
-constexpr uint3 countbits(int64_t3 x) {
+const inline uint3 countbits(int64_t3 x) {
   return __builtin_elementwise_popcount(x);
 }
-constexpr uint4 countbits(int64_t4 x) {
+const inline uint4 countbits(int64_t4 x) {
   return __builtin_elementwise_popcount(x);
 }
 
-constexpr uint countbits(uint64_t x) {
+const inline uint countbits(uint64_t x) {
   return __builtin_elementwise_popcount(x);
 }
-constexpr uint2 countbits(uint64_t2 x) {
+const inline uint2 countbits(uint64_t2 x) {
   return __builtin_elementwise_popcount(x);
 }
-constexpr uint3 countbits(uint64_t3 x) {
+const inline uint3 countbits(uint64_t3 x) {
   return __builtin_elementwise_popcount(x);
 }
-constexpr uint4 countbits(uint64_t4 x) {
+const inline uint4 countbits(uint64_t4 x) {
   return __builtin_elementwise_popcount(x);
 }
 



More information about the llvm-commits mailing list