[clang] Implement resource binding type prefix mismatch flag setting logic (PR #97103)
Xiang Li via cfe-commits
cfe-commits at lists.llvm.org
Fri Jun 28 16:50:31 PDT 2024
================
@@ -437,7 +460,206 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
D->addAttr(NewAttr);
}
+/// Set of boolean facts gathered about a declaration that carries a register
+/// binding. Every flag starts out cleared.
+struct register_binding_flags {
+  // Broad category of the bound declaration.
+  // NOTE(review): these four are assigned outside this excerpt -- confirm
+  // their exact meaning against the code that sets them.
+  bool resource{false};
+  bool udt{false};
+  bool other{false};
+  bool basic{false};
+
+  // Resource classes encountered (traverseType sets these for
+  // ResourceClass::SRV / UAV / CBuffer / Sampler respectively).
+  bool srv{false};
+  bool uav{false};
+  bool cbv{false};
+  bool sampler{false};
+
+  // Set by traverseType when an integral, enumeration or floating-point type
+  // is encountered.
+  bool contains_numeric{false};
+  // NOTE(review): presumably marks a declaration living in the default
+  // globals buffer -- the assignment is not visible here, confirm.
+  bool default_globals{false};
+};
+
+/// Returns true when \p decl has an HLSLBufferDecl anywhere in its chain of
+/// enclosing declaration contexts; returns false for a null \p decl or when
+/// no such context is found.
+bool isDeclaredWithinCOrTBuffer(const Decl *decl) {
+  if (!decl)
+    return false;
+
+  // Climb from the immediate context to the outermost one.
+  for (const DeclContext *DC = decl->getDeclContext(); DC;
+       DC = DC->getParent())
+    if (isa<HLSLBufferDecl>(DC))
+      return true;
+
+  return false;
+}
+
+/// Looks up the HLSLResourceAttr attached to the record type behind
+/// \p SamplerUAVOrSRV.
+///
+/// The variable's type is first reduced to its pointee / array element type.
+/// If that is a builtin type, nullptr is returned. Otherwise the type must
+/// resolve to a C++ record; for a class template specialization the lookup is
+/// performed on the templated declaration of the specialized template. The
+/// attribute is read from the canonical declaration and may be nullptr when
+/// the record carries none.
+const HLSLResourceAttr *
+getHLSLResourceAttrFromVarDecl(VarDecl *SamplerUAVOrSRV) {
+  const Type *ElemTy =
+      SamplerUAVOrSRV->getType()->getPointeeOrArrayElementType();
+  if (!ElemTy)
+    llvm_unreachable("Resource class must have an element type.");
+
+  // Builtin element types have no record declaration to inspect.
+  // NOTE(review): no diagnostic is emitted here; the caller only sees nullptr.
+  if (isa<BuiltinType>(ElemTy))
+    return nullptr;
+
+  const CXXRecordDecl *RD = ElemTy->getAsCXXRecordDecl();
+  if (!RD)
+    llvm_unreachable("Resource class should have a resource type declaration.");
+
+  // For specializations, look at the template's own record declaration.
+  if (const auto *Spec = dyn_cast<ClassTemplateSpecializationDecl>(RD))
+    RD = Spec->getSpecializedTemplate()->getTemplatedDecl();
+
+  return RD->getCanonicalDecl()->getAttr<HLSLResourceAttr>();
+}
+
+/// Recursively inspects \p T and records in \p r what it is made of.
+///
+///  - Integral, enumeration and floating-point types set `contains_numeric`.
+///  - A class template specialization whose templated declaration carries an
+///    HLSLResourceAttr sets the flag matching its resource class
+///    (srv / uav / cbv / sampler); its fields are not visited.
+///  - Any other record with a complete definition is walked field by field.
+///  - Every other type leaves \p r untouched.
+///
+/// \param T type to inspect.
+/// \param r flag set that is updated in place; flags are only ever set,
+///          never cleared.
+void traverseType(QualType T, register_binding_flags &r) {
+  if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
+    r.contains_numeric = true;
+    return;
+  }
+
+  const RecordType *RT = T->getAs<RecordType>();
+  if (!RT)
+    return;
+
+  const RecordDecl *SubRD = RT->getDecl();
+  if (const auto *TDecl = dyn_cast<ClassTemplateSpecializationDecl>(SubRD)) {
+    const CXXRecordDecl *TheRecordDecl =
+        TDecl->getSpecializedTemplate()->getTemplatedDecl();
+    TheRecordDecl = TheRecordDecl->getCanonicalDecl();
+    // getAttr returns null when the template is not a resource type (for
+    // example a user-defined template struct). Only classify real resources;
+    // anything else falls through to the field walk below instead of
+    // dereferencing a null attribute.
+    if (const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>()) {
+      switch (Attr->getResourceClass()) {
+      case llvm::hlsl::ResourceClass::SRV:
+        r.srv = true;
+        break;
+      case llvm::hlsl::ResourceClass::UAV:
+        r.uav = true;
+        break;
+      case llvm::hlsl::ResourceClass::CBuffer:
+        r.cbv = true;
+        break;
+      case llvm::hlsl::ResourceClass::Sampler:
+        r.sampler = true;
+        break;
+      }
+      return;
+    }
+  }
+
+  if (SubRD->isCompleteDefinition())
+    for (const FieldDecl *Field : SubRD->fields())
+      traverseType(Field->getType(), r);
+}
+
+void setResourceClassFlagsFromRecordDecl(register_binding_flags &r,
+ const RecordDecl *RD) {
+ if (!RD)
+ return;
+
+ if (RD->isCompleteDefinition()) {
+ for (auto Field : RD->fields()) {
+ QualType T = Field->getType();
----------------
python3kgae wrote:
Why do we iterate the fields here?
https://github.com/llvm/llvm-project/pull/97103
More information about the cfe-commits
mailing list