[clang] [llvm] [HLSL] Add handle initialization for simple resource declarations (PR #111207)

Helena Kotas via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 16 21:28:37 PDT 2024


https://github.com/hekota updated https://github.com/llvm/llvm-project/pull/111207

>From a13f62d2b5cf1bd1ee7016fce5e0fd95531bf7a2 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Fri, 4 Oct 2024 13:19:27 -0700
Subject: [PATCH 1/4] [HLSL] Add handle initialization for simple resource
 declarations

Adds an `@_init_resource_bindings()` function, called from the module
initialization function, that contains `handle.fromBinding` intrinsic calls
for simple resource declarations. Arrays of resources and resources inside
user-defined types are not supported yet.
---
 clang/lib/CodeGen/CGDeclCXX.cpp               |   5 +
 clang/lib/CodeGen/CGHLSLRuntime.cpp           | 111 ++++++++++++++++++
 clang/lib/CodeGen/CGHLSLRuntime.h             |   7 ++
 clang/lib/CodeGen/CodeGenModule.cpp           |   3 +
 .../builtins/RWBuffer-constructor.hlsl        |  26 ++--
 .../StructuredBuffer-constructor.hlsl         |  27 +++--
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  10 ++
 7 files changed, 164 insertions(+), 25 deletions(-)

diff --git a/clang/lib/CodeGen/CGDeclCXX.cpp b/clang/lib/CodeGen/CGDeclCXX.cpp
index 8dcb5f61006196..834c5b2d65db42 100644
--- a/clang/lib/CodeGen/CGDeclCXX.cpp
+++ b/clang/lib/CodeGen/CGDeclCXX.cpp
@@ -1121,6 +1121,11 @@ CodeGenFunction::GenerateCXXGlobalInitFunc(llvm::Function *Fn,
       if (Decls[i])
         EmitRuntimeCall(Decls[i]);
 
+    if (getLangOpts().HLSL)
+      if (llvm::Function *ResInitFn =
+              CGM.getHLSLRuntime().createResourceBindingInitFn())
+        Builder.CreateCall(llvm::FunctionCallee(ResInitFn), {});
+
     Scope.ForceCleanup();
 
     if (ExitBlock) {
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 3237d93ca31ceb..23ed24eaf5cb27 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -18,8 +18,13 @@
 #include "TargetInfo.h"
 #include "clang/AST/Decl.h"
 #include "clang/Basic/TargetOptions.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/Alignment.h"
+
 #include "llvm/Support/FormatVariadic.h"
 
 using namespace clang;
@@ -489,3 +494,109 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
       GV->eraseFromParent();
   }
 }
+
+// Returns handle type of a resource, if the VarDecl is a resource
+// or an array of resources
+static const HLSLAttributedResourceType *
+findHandleTypeOnResource(const VarDecl *VD) {
+  // If VarDecl is a resource class, the first field must
+  // be the resource handle of type HLSLAttributedResourceType
+  assert(VD != nullptr && "expected VarDecl");
+  const clang::Type *Ty = VD->getType()->getPointeeOrArrayElementType();
+  if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) {
+    if (!RD->fields().empty()) {
+      const auto &FirstFD = RD->fields().begin();
+      return dyn_cast<HLSLAttributedResourceType>(
+          FirstFD->getType().getTypePtr());
+    }
+  }
+  return nullptr;
+}
+
+void CGHLSLRuntime::handleGlobalVarDefinition(const VarDecl *VD,
+                                              llvm::GlobalVariable *Var) {
+  // If the global variable has resource binding, add it to the list of globals
+  // that need resource binding initialization.
+  const HLSLResourceBindingAttr *RBA = VD->getAttr<HLSLResourceBindingAttr>();
+  if (!RBA)
+    return;
+
+  // FIXME: support for resource arrays or resource fields on user defined
+  // classes is not yet implemented
+  if (RBA->ResourceField != nullptr || VD->getType()->isArrayType())
+    return;
+
+  ResourcesToBind.emplace_back(std::make_pair(VD, Var));
+}
+
+llvm::Function *CGHLSLRuntime::createResourceBindingInitFn() {
+  // No resources to bind
+  if (ResourcesToBind.empty())
+    return nullptr;
+
+  LLVMContext &Ctx = CGM.getLLVMContext();
+
+  llvm::Function *InitResBindingsFunc =
+      llvm::Function::Create(llvm::FunctionType::get(CGM.VoidTy, false),
+                             llvm::GlobalValue::InternalLinkage,
+                             "_init_resource_bindings", CGM.getModule());
+
+  llvm::BasicBlock *EntryBB =
+      llvm::BasicBlock::Create(Ctx, "entry", InitResBindingsFunc);
+  CGBuilderTy Builder(CGM, Ctx);
+  const DataLayout &DL = CGM.getModule().getDataLayout();
+  Builder.SetInsertPoint(EntryBB);
+
+  for (auto I : ResourcesToBind) {
+    const VarDecl *VD = I.first;
+    llvm::GlobalVariable *Var = I.second;
+
+    for (Attr *A : VD->getAttrs()) {
+      HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
+      if (!RBA)
+        continue;
+
+      if (RBA->getResourceField() != nullptr) {
+        // FIXME: Register bindings inside user defined struct are not yet
+        // supported
+        llvm_unreachable("Register bindings inside user defined struct are not "
+                         "implemented yet");
+        continue;
+      }
+
+      const HLSLAttributedResourceType *AttrResType =
+          findHandleTypeOnResource(VD);
+      assert(AttrResType != nullptr &&
+             "Resource class must have a handle of HLSLAttributedResourceType");
+
+      llvm::Type *TargetTy =
+          CGM.getTargetCodeGenInfo().getHLSLType(CGM, AttrResType);
+      assert(TargetTy != nullptr &&
+             "Failed to convert resource handle to target type");
+
+      llvm::Value *Args[] = {
+          llvm::ConstantInt::get(CGM.IntTy,
+                                 RBA->getSpaceNumber()), /*RegisterSpace*/
+          llvm::ConstantInt::get(CGM.IntTy,
+                                 RBA->getSlotNumber()), /*RegisterSlot*/
+          // FIXME: resource arrays are not yet implemented
+          llvm::ConstantInt::get(CGM.IntTy, 1), /*Range*/
+          llvm::ConstantInt::get(CGM.IntTy, 0), /*Index*/
+          // FIXME: NonUniformResourceIndex bit is not yet implemented
+          llvm::ConstantInt::get(llvm::Type::getInt1Ty(Ctx),
+                                 false) /*Non-uniform*/
+      };
+      llvm::Value *CreateHandle = Builder.CreateIntrinsic(
+          /*ReturnType=*/TargetTy, getCreateHandleFromBindingIntrinsic(), Args,
+          nullptr, Twine(VD->getName()).concat("_h"));
+
+      llvm::Value *HandleRef =
+          Builder.CreateStructGEP(Var->getValueType(), Var, 0);
+      Builder.CreateAlignedStore(CreateHandle, HandleRef,
+                                 HandleRef->getPointerAlignment(DL));
+    }
+  }
+
+  Builder.CreateRetVoid();
+  return InitResBindingsFunc;
+}
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 6722d2c7c50a2b..0b9d2f165f322b 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -89,6 +89,8 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
 
+  GENERATE_HLSL_INTRINSIC_FUNCTION(CreateHandleFromBinding, handle_fromBinding)
+
   //===----------------------------------------------------------------------===//
   // End of reserved area for HLSL intrinsic getters.
   //===----------------------------------------------------------------------===//
@@ -134,6 +136,8 @@ class CGHLSLRuntime {
 
   void emitEntryFunction(const FunctionDecl *FD, llvm::Function *Fn);
   void setHLSLFunctionAttributes(const FunctionDecl *FD, llvm::Function *Fn);
+  void handleGlobalVarDefinition(const VarDecl *VD, llvm::GlobalVariable *Var);
+  llvm::Function *createResourceBindingInitFn();
 
 private:
   void addBufferResourceAnnotation(llvm::GlobalVariable *GV,
@@ -145,6 +149,9 @@ class CGHLSLRuntime {
   void addBufferDecls(const DeclContext *DC, Buffer &CB);
   llvm::Triple::ArchType getArch();
   llvm::SmallVector<Buffer> Buffers;
+
+  llvm::SmallVector<std::pair<const VarDecl *, llvm::GlobalVariable *>>
+      ResourcesToBind;
 };
 
 } // namespace CodeGen
diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp
index 25c1c496a4f27f..1dd969fb0a4187 100644
--- a/clang/lib/CodeGen/CodeGenModule.cpp
+++ b/clang/lib/CodeGen/CodeGenModule.cpp
@@ -5617,6 +5617,9 @@ void CodeGenModule::EmitGlobalVarDefinition(const VarDecl *D,
     getCUDARuntime().handleVarRegistration(D, *GV);
   }
 
+  if (LangOpts.HLSL)
+    getHLSLRuntime().handleGlobalVarDefinition(D, GV);
+
   GV->setInitializer(Init);
   if (emitter)
     emitter->finalize(GV);
diff --git a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
index 19699dcf14d9f4..844edea3d0f319 100644
--- a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
@@ -1,19 +1,21 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
-// RUN: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefix=CHECK-SPIRV
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN-DISABLED: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
 
-// XFAIL: *
-// This expectedly fails because create.handle is no longer called
-// from RWBuffer constructor and the replacement has not been
-// implemented yet. This test should be updated to expect
-// dx.create.handleFromBinding as part of issue #105076.
+// NOTE: SPIRV codegen for resource types is not yet implemented
 
-RWBuffer<float> Buf;
+RWBuffer<float> Buf : register(u5, space3);
 
 // CHECK: define linkonce_odr noundef ptr @"??0?$RWBuffer at M@hlsl@@QAA at XZ"
 // CHECK-NEXT: entry:
 
-// CHECK: %[[HandleRes:[0-9]+]] = call ptr @llvm.dx.create.handle(i8 1)
-// CHECK: store ptr %[[HandleRes]], ptr %h, align 4
+// CHECK: define internal void @_GLOBAL__sub_I_RWBuffer_constructor.hlsl()
+// CHECK-NEXT: entry:
+// CHECK-NEXT: call void @"??__EBuf@@YAXXZ"()
+// CHECK-NEXT: call void @_init_resource_bindings()
 
-// CHECK-SPIRV: %[[HandleRes:[0-9]+]] = call ptr @llvm.spv.create.handle(i8 1)
-// CHECK-SPIRV: store ptr %[[HandleRes]], ptr %h, align 8
+// CHECK: define internal void @_init_resource_bindings() {
+// CHECK-NEXT: entry:
+// CHECK-DXIL-NEXT: %Buf_h = call target("dx.TypedBuffer", float, 1, 0, 0) @llvm.dx.handle.fromBinding.tdx.TypedBuffer_f32_1_0_0t(i32 3, i32 5, i32 1, i32 0, i1 false)
+// CHECK-DXIL-NEXT: store target("dx.TypedBuffer", float, 1, 0, 0) %Buf_h, ptr @"?Buf@@3V?$RWBuffer at M@hlsl@@A", align 4
+// CHECK-SPIRV-NEXT: %Buf_h = call target("dx.TypedBuffer", float, 1, 0, 0) @llvm.spv.handle.fromBinding.tdx.TypedBuffer_f32_1_0_0t(i32 3, i32 5, i32 1, i32 0, i1 false)
+// CHECK-SPIRV-NEXT: store target("dx.TypedBuffer", float, 1, 0, 0) %Buf_h, ptr @"?Buf@@3V?$RWBuffer at M@hlsl@@A", align 4
diff --git a/clang/test/CodeGenHLSL/builtins/StructuredBuffer-constructor.hlsl b/clang/test/CodeGenHLSL/builtins/StructuredBuffer-constructor.hlsl
index 178332d03e6404..5014d6959d7973 100644
--- a/clang/test/CodeGenHLSL/builtins/StructuredBuffer-constructor.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/StructuredBuffer-constructor.hlsl
@@ -1,19 +1,20 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
-// RUN: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefix=CHECK-SPIRV
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN-DISABLED: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
 
-// XFAIL: *
-// This expectedly fails because create.handle is no longer invoked
-// from StructuredBuffer constructor and the replacement has not been
-// implemented yet. This test should be updated to expect
-// dx.create.handleFromBinding as part of issue #105076.
-
-StructuredBuffer<float> Buf;
+// NOTE: SPIRV codegen for resource types is not yet implemented
+StructuredBuffer<float> Buf : register(u10);
 
 // CHECK: define linkonce_odr noundef ptr @"??0?$StructuredBuffer at M@hlsl@@QAA at XZ"
 // CHECK-NEXT: entry:
 
-// CHECK: %[[HandleRes:[0-9]+]] = call ptr @llvm.dx.create.handle(i8 1)
-// CHECK: store ptr %[[HandleRes]], ptr %h, align 4
+// CHECK: define internal void @_GLOBAL__sub_I_StructuredBuffer_constructor.hlsl()
+// CHECK-NEXT: entry:
+// CHECK-NEXT: call void @"??__EBuf@@YAXXZ"()
+// CHECK-NEXT: call void @_init_resource_bindings()
 
-// CHECK-SPIRV: %[[HandleRes:[0-9]+]] = call ptr @llvm.spv.create.handle(i8 1)
-// CHECK-SPIRV: store ptr %[[HandleRes]], ptr %h, align 8
+// CHECK: define internal void @_init_resource_bindings() {
+// CHECK-NEXT: entry:
+// CHECK-DXIL-NEXT: %Buf_h = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.handle.fromBinding.tdx.RawBuffer_f32_1_0t(i32 0, i32 10, i32 1, i32 0, i1 false)
+// CHECK-DXIL-NEXT: store target("dx.RawBuffer", float, 1, 0) %Buf_h, ptr @"?Buf@@3V?$StructuredBuffer at M@hlsl@@A", align 4
+// CHECK-SPIRV-NEXT: %Buf_h = call target("dx.RawBuffer", float, 1, 0) @llvm.spv.handle.fromBinding.tdx.RawBuffer_f32_1_0t(i32 0, i32 10, i32 1, i32 0, i1 false)
+// CHECK-SPIRV-NEXT: store target("dx.RawBuffer", float, 1, 0) %Buf_h, ptr @"?Buf@@3V?$StructuredBuffer at M@hlsl@@A", align 4
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 88059aa8378140..bc09d1f34503b5 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -85,4 +85,14 @@ let TargetPrefix = "spv" in {
   def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
   def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
   def int_spv_radians : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
+
+  // Create a resource handle from binding information. Returns a target type
+  // appropriate for the kind of resource, given a register space ID, the lower
+  // bound and range size of the binding, an index into that range, and an
+  // indicator of whether that index may be non-uniform.
+  def int_spv_handle_fromBinding
+      : DefaultAttrsIntrinsic<
+            [llvm_any_ty],
+            [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
+            [IntrNoMem]>;
 }

>From b83da33a675ec80f61c4771b75beeca24ba80307 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Thu, 10 Oct 2024 08:28:11 -0700
Subject: [PATCH 2/4] Remove ResourceField use; make findHandleTypeOnResource
 take a QualType instead of a VarDecl

---
 clang/lib/CodeGen/CGHLSLRuntime.cpp | 27 +++++++++------------------
 1 file changed, 9 insertions(+), 18 deletions(-)

diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 23ed24eaf5cb27..62f8b81eede480 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -495,14 +495,12 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
   }
 }
 
-// Returns handle type of a resource, if the VarDecl is a resource
+// Returns handle type of a resource, if the type is a resource
 // or an array of resources
-static const HLSLAttributedResourceType *
-findHandleTypeOnResource(const VarDecl *VD) {
-  // If VarDecl is a resource class, the first field must
+static const HLSLAttributedResourceType *findHandleTypeOnResource(QualType QT) {
+  // If the type is a resource class, the first field must
   // be the resource handle of type HLSLAttributedResourceType
-  assert(VD != nullptr && "expected VarDecl");
-  const clang::Type *Ty = VD->getType()->getPointeeOrArrayElementType();
+  const clang::Type *Ty = QT->getUnqualifiedDesugaredType();
   if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) {
     if (!RD->fields().empty()) {
       const auto &FirstFD = RD->fields().begin();
@@ -521,9 +519,10 @@ void CGHLSLRuntime::handleGlobalVarDefinition(const VarDecl *VD,
   if (!RBA)
     return;
 
-  // FIXME: support for resource arrays or resource fields on user defined
-  // classes is not yet implemented
-  if (RBA->ResourceField != nullptr || VD->getType()->isArrayType())
+  if (!findHandleTypeOnResource(VD->getType()))
+    // FIXME: Only simple declarations of resources are supported for now.
+    // Arrays of resources or resources in user defined classes are
+    // not implemented yet.
     return;
 
   ResourcesToBind.emplace_back(std::make_pair(VD, Var));
@@ -556,16 +555,8 @@ llvm::Function *CGHLSLRuntime::createResourceBindingInitFn() {
       if (!RBA)
         continue;
 
-      if (RBA->getResourceField() != nullptr) {
-        // FIXME: Register bindings inside user defined struct are not yet
-        // supported
-        llvm_unreachable("Register bindings inside user defined struct are not "
-                         "implemented yet");
-        continue;
-      }
-
       const HLSLAttributedResourceType *AttrResType =
-          findHandleTypeOnResource(VD);
+          findHandleTypeOnResource(VD->getType());
       assert(AttrResType != nullptr &&
              "Resource class must have a handle of HLSLAttributedResourceType");
 

>From 8efaf7ca1fef89506cef30e291011021bbbdea16 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Tue, 15 Oct 2024 22:14:56 -0700
Subject: [PATCH 3/4] Address code review feedback

---
 clang/lib/CodeGen/CGDeclCXX.cpp               |  9 ++--
 clang/lib/CodeGen/CGHLSLRuntime.cpp           | 44 ++++++++++---------
 clang/lib/CodeGen/CGHLSLRuntime.h             |  2 +
 .../builtins/RWBuffer-constructor.hlsl        |  1 +
 4 files changed, 32 insertions(+), 24 deletions(-)

diff --git a/clang/lib/CodeGen/CGDeclCXX.cpp b/clang/lib/CodeGen/CGDeclCXX.cpp
index 834c5b2d65db42..b4f1a68cfe87f4 100644
--- a/clang/lib/CodeGen/CGDeclCXX.cpp
+++ b/clang/lib/CodeGen/CGDeclCXX.cpp
@@ -1121,10 +1121,13 @@ CodeGenFunction::GenerateCXXGlobalInitFunc(llvm::Function *Fn,
       if (Decls[i])
         EmitRuntimeCall(Decls[i]);
 
-    if (getLangOpts().HLSL)
-      if (llvm::Function *ResInitFn =
-              CGM.getHLSLRuntime().createResourceBindingInitFn())
+    if (getLangOpts().HLSL) {
+      CGHLSLRuntime &CGHLSL = CGM.getHLSLRuntime();
+      if (CGHLSL.needsResourceBindingInitFn()) {
+        llvm::Function *ResInitFn = CGHLSL.createResourceBindingInitFn();
         Builder.CreateCall(llvm::FunctionCallee(ResInitFn), {});
+      }
+    }
 
     Scope.ForceCleanup();
 
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 62f8b81eede480..c6d551894a033e 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -512,7 +512,7 @@ static const HLSLAttributedResourceType *findHandleTypeOnResource(QualType QT) {
 }
 
 void CGHLSLRuntime::handleGlobalVarDefinition(const VarDecl *VD,
-                                              llvm::GlobalVariable *Var) {
+                                              llvm::GlobalVariable *GV) {
   // If the global variable has resource binding, add it to the list of globals
   // that need resource binding initialization.
   const HLSLResourceBindingAttr *RBA = VD->getAttr<HLSLResourceBindingAttr>();
@@ -525,15 +525,19 @@ void CGHLSLRuntime::handleGlobalVarDefinition(const VarDecl *VD,
     // not implemented yet.
     return;
 
-  ResourcesToBind.emplace_back(std::make_pair(VD, Var));
+  ResourcesToBind.emplace_back(VD, GV);
+}
+
+bool CGHLSLRuntime::needsResourceBindingInitFn() {
+  return !ResourcesToBind.empty();
 }
 
 llvm::Function *CGHLSLRuntime::createResourceBindingInitFn() {
   // No resources to bind
-  if (ResourcesToBind.empty())
-    return nullptr;
+  assert(needsResourceBindingInitFn() && "no resources to bind");
 
   LLVMContext &Ctx = CGM.getLLVMContext();
+  llvm::Type *Int1Ty = llvm::Type::getInt1Ty(Ctx);
 
   llvm::Function *InitResBindingsFunc =
       llvm::Function::Create(llvm::FunctionType::get(CGM.VoidTy, false),
@@ -546,10 +550,7 @@ llvm::Function *CGHLSLRuntime::createResourceBindingInitFn() {
   const DataLayout &DL = CGM.getModule().getDataLayout();
   Builder.SetInsertPoint(EntryBB);
 
-  for (auto I : ResourcesToBind) {
-    const VarDecl *VD = I.first;
-    llvm::GlobalVariable *Var = I.second;
-
+  for (const auto &[VD, GV] : ResourcesToBind) {
     for (Attr *A : VD->getAttrs()) {
       HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
       if (!RBA)
@@ -557,6 +558,10 @@ llvm::Function *CGHLSLRuntime::createResourceBindingInitFn() {
 
       const HLSLAttributedResourceType *AttrResType =
           findHandleTypeOnResource(VD->getType());
+
+      // FIXME: Only simple declarations of resources are supported for now.
+      // Arrays of resources or resources in user defined classes are
+      // not implemented yet.
       assert(AttrResType != nullptr &&
              "Resource class must have a handle of HLSLAttributedResourceType");
 
@@ -565,24 +570,21 @@ llvm::Function *CGHLSLRuntime::createResourceBindingInitFn() {
       assert(TargetTy != nullptr &&
              "Failed to convert resource handle to target type");
 
-      llvm::Value *Args[] = {
-          llvm::ConstantInt::get(CGM.IntTy,
-                                 RBA->getSpaceNumber()), /*RegisterSpace*/
-          llvm::ConstantInt::get(CGM.IntTy,
-                                 RBA->getSlotNumber()), /*RegisterSlot*/
-          // FIXME: resource arrays are not yet implemented
-          llvm::ConstantInt::get(CGM.IntTy, 1), /*Range*/
-          llvm::ConstantInt::get(CGM.IntTy, 0), /*Index*/
-          // FIXME: NonUniformResourceIndex bit is not yet implemented
-          llvm::ConstantInt::get(llvm::Type::getInt1Ty(Ctx),
-                                 false) /*Non-uniform*/
-      };
+      auto *Space = llvm::ConstantInt::get(CGM.IntTy, RBA->getSpaceNumber());
+      auto *Slot = llvm::ConstantInt::get(CGM.IntTy, RBA->getSlotNumber());
+      // FIXME: resource arrays are not yet implemented
+      auto *Range = llvm::ConstantInt::get(CGM.IntTy, 1);
+      auto *Index = llvm::ConstantInt::get(CGM.IntTy, 0);
+      // FIXME: NonUniformResourceIndex bit is not yet implemented
+      auto *NonUniform = llvm::ConstantInt::get(Int1Ty, false);
+      llvm::Value *Args[] = {Space, Slot, Range, Index, NonUniform};
+
       llvm::Value *CreateHandle = Builder.CreateIntrinsic(
           /*ReturnType=*/TargetTy, getCreateHandleFromBindingIntrinsic(), Args,
           nullptr, Twine(VD->getName()).concat("_h"));
 
       llvm::Value *HandleRef =
-          Builder.CreateStructGEP(Var->getValueType(), Var, 0);
+          Builder.CreateStructGEP(GV->getValueType(), GV, 0);
       Builder.CreateAlignedStore(CreateHandle, HandleRef,
                                  HandleRef->getPointerAlignment(DL));
     }
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 4294fbcfaca5bb..0d29d697ea551c 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -138,6 +138,8 @@ class CGHLSLRuntime {
   void emitEntryFunction(const FunctionDecl *FD, llvm::Function *Fn);
   void setHLSLFunctionAttributes(const FunctionDecl *FD, llvm::Function *Fn);
   void handleGlobalVarDefinition(const VarDecl *VD, llvm::GlobalVariable *Var);
+
+  bool needsResourceBindingInitFn();
   llvm::Function *createResourceBindingInitFn();
 
 private:
diff --git a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
index 844edea3d0f319..46f959a63f9510 100644
--- a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl
@@ -1,4 +1,5 @@
 // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// FIXME: SPIR-V codegen of llvm.spv.handle.fromBinding is not yet implemented
 // RUN-DISABLED: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
 
 // NOTE: SPIRV codegen for resource types is not yet implemented

>From eada96658d715b08ad3ccfbd712afef20b5fe538 Mon Sep 17 00:00:00 2001
From: Helena Kotas <hekotas at microsoft.com>
Date: Wed, 16 Oct 2024 11:21:23 -0700
Subject: [PATCH 4/4] Move findHandleTypeOnResource to static method on
 HLSLAttributedResourceType

---
 clang/include/clang/AST/Type.h      |  4 ++++
 clang/lib/AST/Type.cpp              | 15 +++++++++++++++
 clang/lib/CodeGen/CGHLSLRuntime.cpp | 22 ++++++----------------
 clang/lib/Sema/SemaHLSL.cpp         | 12 ++----------
 4 files changed, 27 insertions(+), 26 deletions(-)

diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index deda5b3f70f343..40e617bf8f3b8d 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -6320,6 +6320,10 @@ class HLSLAttributedResourceType : public Type, public llvm::FoldingSetNode {
   static bool classof(const Type *T) {
     return T->getTypeClass() == HLSLAttributedResource;
   }
+
+  /// Returns the handle type of HLSL resource class RT, or nullptr otherwise.
+  static const HLSLAttributedResourceType *
+  findHandleTypeOnResource(const Type *RT);
 };
 
 class TemplateTypeParmType : public Type, public llvm::FoldingSetNode {
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index f013ed11d12935..e7493d7cdf0e29 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -5334,3 +5334,18 @@ std::string FunctionEffectWithCondition::description() const {
     Result += "(expr)";
   return Result;
 }
+
+const HLSLAttributedResourceType *
+HLSLAttributedResourceType::findHandleTypeOnResource(const Type *RT) {
+  // RT is treated as an HLSL resource class if its first field's type is an
+  // HLSLAttributedResourceType (the handle); returns that type or nullptr.
+  const clang::Type *Ty = RT->getUnqualifiedDesugaredType();
+  if (const RecordDecl *RD = Ty->getAsCXXRecordDecl()) {
+    if (!RD->fields().empty()) {
+      const auto &FirstFD = RD->fields().begin();
+      return dyn_cast<HLSLAttributedResourceType>(
+          FirstFD->getType().getTypePtr());
+    }
+  }
+  return nullptr;
+}
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index c6d551894a033e..c934145bd8b3b4 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -495,20 +495,10 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
   }
 }
 
-// Returns handle type of a resource, if the type is a resource
-// or an array of resources
-static const HLSLAttributedResourceType *findHandleTypeOnResource(QualType QT) {
-  // If the type is a resource class, the first field must
-  // be the resource handle of type HLSLAttributedResourceType
-  const clang::Type *Ty = QT->getUnqualifiedDesugaredType();
-  if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) {
-    if (!RD->fields().empty()) {
-      const auto &FirstFD = RD->fields().begin();
-      return dyn_cast<HLSLAttributedResourceType>(
-          FirstFD->getType().getTypePtr());
-    }
-  }
-  return nullptr;
+// Returns handle type from a resource, if the type is a resource
+static const HLSLAttributedResourceType *
+findHandleTypeOnResource(const clang::Type *Ty) {
+  return HLSLAttributedResourceType::findHandleTypeOnResource(Ty);
 }
 
 void CGHLSLRuntime::handleGlobalVarDefinition(const VarDecl *VD,
@@ -519,7 +509,7 @@ void CGHLSLRuntime::handleGlobalVarDefinition(const VarDecl *VD,
   if (!RBA)
     return;
 
-  if (!findHandleTypeOnResource(VD->getType()))
+  if (!findHandleTypeOnResource(VD->getType().getTypePtr()))
     // FIXME: Only simple declarations of resources are supported for now.
     // Arrays of resources or resources in user defined classes are
     // not implemented yet.
@@ -557,7 +547,7 @@ llvm::Function *CGHLSLRuntime::createResourceBindingInitFn() {
         continue;
 
       const HLSLAttributedResourceType *AttrResType =
-          findHandleTypeOnResource(VD->getType());
+          findHandleTypeOnResource(VD->getType().getTypePtr());
 
       // FIXME: Only simple declarations of resources are supported for now.
       // Arrays of resources or resources in user defined classes are
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 698fdbed0484e5..84a0655ca30d03 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1000,16 +1000,8 @@ static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
 }
 
 static const HLSLAttributedResourceType *
-findAttributedResourceTypeOnField(VarDecl *VD) {
-  assert(VD != nullptr && "expected VarDecl");
-  if (RecordDecl *RD = getRecordDeclFromVarDecl(VD)) {
-    for (auto *FD : RD->fields()) {
-      if (const HLSLAttributedResourceType *AttrResType =
-              dyn_cast<HLSLAttributedResourceType>(FD->getType().getTypePtr()))
-        return AttrResType;
-    }
-  }
-  return nullptr;
+findHandleTypeOnResource(const Type *Ty) {
+  return HLSLAttributedResourceType::findHandleTypeOnResource(Ty);
 }
 
 // Iterate over RecordType fields and return true if any of them matched the



More information about the llvm-commits mailing list