[llvm-branch-commits] [llvm] [DirectX] Lower `@llvm.dx.handle.fromBinding` to DXIL ops (PR #104251)

Justin Bogner via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Aug 23 12:57:31 PDT 2024


================
@@ -119,6 +123,142 @@ class OpLowerer {
     });
   }
 
+  /// Create a temporary cast of \p V to type \p Ty by calling the
+  /// `dx_cast_handle` intrinsic, so that a `target("dx")` value and a
+  /// `dx.types.Handle` can stand in for one another while ops are lowered in
+  /// a piecemeal way. The call is recorded in `CleanupCasts` and returned; by
+  /// the end of lowering all such casts are intended to be redundant/removed.
+  Value *createTmpHandleCast(Value *V, Type *Ty) {
+    Function *CastFn = Intrinsic::getDeclaration(&M, Intrinsic::dx_cast_handle,
+                                                 {Ty, V->getType()});
+    CallInst *Cast = OpBuilder.getIRB().CreateCall(CastFn, {V});
+    CleanupCasts.push_back(Cast); // Remembered so the cast can be cleaned up.
+    return Cast;
+  }
+
+  void cleanupHandleCasts() { // Erase every cast recorded in CleanupCasts.
+    SmallVector<CallInst *> ToRemove; // Casts erased only after the main loop.
+    SmallVector<Function *> CastFns;  // Callees of the casts (may repeat).
+
+    for (CallInst *Cast : CleanupCasts) {
+      // These casts were only put in to ease the move from `target("dx")` types
+      // to `dx.types.Handle` in a piecemeal way. At this point, all of the
+      // non-cast uses should now be `dx.types.Handle`, and remaining casts
+      // should all form pairs to and from the now unused `target("dx")` type.
+      CastFns.push_back(Cast->getCalledFunction());
+
+      // If the cast is not to `dx.types.Handle`, it should be the first part of
+      // the pair. Keep track so we can remove it once it has no more uses.
+      if (Cast->getType() != OpBuilder.getHandleType()) {
+        ToRemove.push_back(Cast);
+        continue;
+      }
+      // Otherwise, we're the second cast in a pair. Forward the first cast's
+      // operand to our users and remove the (second) cast.
+      CallInst *Def = cast<CallInst>(Cast->getOperand(0));
+      assert(Def->getIntrinsicID() == Intrinsic::dx_cast_handle &&
+             "Unbalanced pair of temporary handle casts");
+      Cast->replaceAllUsesWith(Def->getOperand(0));
+      Cast->eraseFromParent();
+    }
+    for (CallInst *Cast : ToRemove) { // First halves of the pairs.
+      assert(Cast->user_empty() && "Temporary handle cast still has users");
+      Cast->eraseFromParent();
+    }
+
+    // Deduplicate the cast functions so that we only erase each one once.
+    llvm::sort(CastFns);
+    CastFns.erase(llvm::unique(CastFns), CastFns.end());
+    for (Function *F : CastFns)
+      F->eraseFromParent();
+
+    CleanupCasts.clear(); // Every recorded cast has been erased above.
+  }
+
+  void lowerToCreateHandle(Function &F) { // Rewrite F's calls as CreateHandle.
+    IRBuilder<> &IRB = OpBuilder.getIRB();
+    Type *Int8Ty = IRB.getInt8Ty();
+    Type *Int32Ty = IRB.getInt32Ty();
+
+    replaceFunction(F, [&](CallInst *CI) -> Error { // presumably per call of F
+      IRB.SetInsertPoint(CI); // New instructions are emitted before CI.
+
+      auto *It = DRM.find(CI); // Resource info recorded for this call.
+      assert(It != DRM.end() && "Resource not in map?");
+      dxil::ResourceInfo &RI = *It;
+      const auto &Binding = RI.getBinding();
+
+      std::array<Value *, 4> Args{ // i8 class, i32 record ID, CI operands 3, 4
+          ConstantInt::get(Int8Ty, llvm::to_underlying(RI.getResourceClass())),
+          ConstantInt::get(Int32Ty, Binding.RecordID), CI->getArgOperand(3),
+          CI->getArgOperand(4)};
+      Expected<CallInst *> OpCall =
+          OpBuilder.tryCreateOp(OpCode::CreateHandle, Args);
+      if (Error E = OpCall.takeError()) // Hand op-creation failure to caller.
+        return E;
+
+      Value *Cast = createTmpHandleCast(*OpCall, CI->getType()); // CI's type.
+
+      CI->replaceAllUsesWith(Cast); // Users of CI now see the cast op result.
+      CI->eraseFromParent();
+      return Error::success();
+    });
+  }
+
+  void lowerToBindAndAnnotateHandle(Function &F) {
+    IRBuilder<> &IRB = OpBuilder.getIRB();
+
+    replaceFunction(F, [&](CallInst *CI) -> Error {
+      IRB.SetInsertPoint(CI);
+
+      auto *It = DRM.find(CI);
+      assert(It != DRM.end() && "Resource not in map?");
+      dxil::ResourceInfo &RI = *It;
+
+      const auto &Binding = RI.getBinding();
+      std::pair<uint32_t, uint32_t> Props = RI.getAnnotateProps();
+
+      // For `CreateHandleFromBinding` we need the upper bound rather than the
+      // size, so we need to be careful about the difference for "unbounded".
+      uint32_t Unbounded = std::numeric_limits<uint32_t>::max();
----------------
bogner wrote:

I personally think it's clearer to spell out that it's the maximum uint32_t rather than having to chase through some global in a header to figure out the sentinel value for something simple like this.

https://github.com/llvm/llvm-project/pull/104251


More information about the llvm-branch-commits mailing list