[clang] [llvm] [HLSL] Re-implement countbits with the correct return type (PR #113189)

Justin Bogner via cfe-commits cfe-commits at lists.llvm.org
Thu Oct 24 14:19:45 PDT 2024


================
@@ -460,6 +460,70 @@ class OpLowerer {
     });
   }
 
+  [[nodiscard]] bool lowerCtpopToCountBits(Function &F) {
+    IRBuilder<> &IRB = OpBuilder.getIRB();
+    Type *Int32Ty = IRB.getInt32Ty();
+
+    return replaceFunction(F, [&](CallInst *CI) -> Error {
+      IRB.SetInsertPoint(CI);
+      SmallVector<Value *> Args;
+      Args.append(CI->arg_begin(), CI->arg_end());
+
+      Type *RetTy = Int32Ty;
+      Type *FRT = F.getReturnType();
+      if (const auto *VT = dyn_cast<VectorType>(FRT))
+        RetTy = VectorType::get(RetTy, VT);
+
+      Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
+          dxil::OpCode::CountBits, Args, CI->getName(), RetTy);
+      if (Error E = OpCall.takeError())
+        return E;
+
+      // If the result type is 32 bits we can do a direct replacement.
+      if (FRT->isIntOrIntVectorTy(32)) {
+        CI->replaceAllUsesWith(*OpCall);
+        CI->eraseFromParent();
+        return Error::success();
+      }
+
+      unsigned CastOp;
+      if (FRT->isIntOrIntVectorTy(16))
+        CastOp = Instruction::ZExt;
+      else { // must be 64 bits
+        assert(FRT->isIntOrIntVectorTy(64) &&
+               "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
+                is supported.");
+        CastOp = Instruction::Trunc;
+      }
+
+      // It is correct to replace the ctpop with the dxil op and
+      // remove all casts to i32
+      bool nonCastInstr = false;
+      for (User *User : make_early_inc_range(CI->users())) {
+        Instruction *I;
+        if ((I = dyn_cast<Instruction>(User)) != NULL &&
+            I->getOpcode() == CastOp && I->getType() == RetTy) {
----------------
bogner wrote:

Burying the assignment in the condition makes this a bit less readable. Probably best to initialize `I` before the condition. Also, it's better to use `nullptr` (not `NULL`), or just rely on the implicit conversion to bool, rather than comparing against `NULL`.
```suggestion
        Instruction *I = dyn_cast<Instruction>(User);
        if (I && I->getOpcode() == CastOp && I->getType() == RetTy) {
```

https://github.com/llvm/llvm-project/pull/113189


More information about the cfe-commits mailing list