[clang] Implement resource binding type prefix mismatch diagnostic infrastructure (PR #97103)

Damyan Pepper via cfe-commits cfe-commits at lists.llvm.org
Tue Jul 9 18:03:28 PDT 2024


================
@@ -437,7 +444,419 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
     D->addAttr(NewAttr);
 }
 
+/// Summary of what a declaration carrying a register binding turned out to
+/// be. It is filled in by HLSLFillRegisterBindingFlags and the type-traversal
+/// helpers below. Every flag starts out false and is only ever switched on.
+struct RegisterBindingFlags {
+  // Category of the declaration itself.
+  bool Resource{false}; // cbuffer/tbuffer, or a variable whose type carries
+                        // an HLSLResourceAttr
+  // NOTE(review): UDT, Other and Basic are not set by any code visible in
+  // this hunk; presumably UDT = user-defined struct, Basic = scalar/vector.
+  bool UDT{false};
+  bool Other{false};
+  bool Basic{false};
+
+  // Resource classes found on the declaration or anywhere inside its type.
+  bool SRV{false};
+  bool UAV{false};
+  bool CBV{false};
+  bool Sampler{false};
+
+  // An integral, enumeration or floating-point member was found.
+  bool ContainsNumeric{false};
+  // A numeric variable declared outside any cbuffer/tbuffer; such a variable
+  // ends up in the implicit $Globals buffer.
+  bool DefaultGlobals{false};
+};
+
+/// Returns true if \p decl is declared directly inside a cbuffer or tbuffer
+/// block, i.e. its immediately enclosing DeclContext is an HLSLBufferDecl.
+/// Returns false for a null \p decl (or a decl with no enclosing context).
+///
+/// Only the immediate DeclContext is inspected; outer contexts are not walked.
+static bool isDeclaredWithinCOrTBuffer(const Decl *decl) {
+  if (!decl)
+    return false;
+  return isa_and_nonnull<HLSLBufferDecl>(decl->getDeclContext());
+}
+
+/// Returns the canonical CXXRecordDecl for the type of \p VD. Pointers are
+/// looked through to the pointee and arrays to their element type. For a class
+/// template specialization (e.g. RWBuffer<float>), the primary template's
+/// pattern declaration is returned, because that is the declaration on which
+/// callers look up HLSLResourceAttr.
+///
+/// Returns nullptr when the element type is not a C++ record, e.g. a scalar,
+/// vector, matrix or enum global such as `float4 g : register(c0);`.
+/// Screening out only BuiltinType is not enough here: vector and enum types
+/// are not BuiltinTypes, yet they have no record declaration either.
+static const CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
+  const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
+  assert(Ty && "Resource class must have an element type.");
+
+  const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
+  if (!TheRecordDecl)
+    return nullptr;
+
+  if (const auto *TDecl =
+          dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
+    TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl();
+  return TheRecordDecl->getCanonicalDecl();
+}
+
+/// Returns the HLSLResourceAttr describing the resource class of either \p VD
+/// or \p CBufferOrTBuffer, or nullptr if the relevant declaration has none.
+///
+/// If \p VD is non-null, the attribute is read from the record declaration of
+/// its (element) type; nullptr is returned when that type is not a record.
+/// Otherwise the attribute is read from \p CBufferOrTBuffer. Callers are
+/// expected to pass exactly one non-null argument; passing two nulls is a
+/// programming error.
+static const HLSLResourceAttr *
+getHLSLResourceAttrFromEitherDecl(VarDecl *VD,
+                                  HLSLBufferDecl *CBufferOrTBuffer) {
+  if (VD) {
+    if (const CXXRecordDecl *TheRecordDecl = getRecordDeclFromVarDecl(VD))
+      return TheRecordDecl->getAttr<HLSLResourceAttr>();
+    return nullptr;
+  }
+  if (CBufferOrTBuffer)
+    return CBufferOrTBuffer->getAttr<HLSLResourceAttr>();
+  llvm_unreachable("one of the two conditions should be true.");
+}
+
+/// Recursively inspects \p T and records in \p r what it contains:
+///  - an integral, enumeration or floating-point scalar sets ContainsNumeric;
+///  - a resource type, i.e. a class template specialization whose primary
+///    template pattern carries an HLSLResourceAttr, sets the flag matching its
+///    resource class;
+///  - any other complete record (including user-defined template
+///    specializations) is searched field by field.
+/// Arrays are looked through to their base element type, so members such as
+/// `RWBuffer<float> Bufs[4]` or `float F[2]` are found as well. Anything else
+/// (pointers, vectors, matrices, ...) is ignored.
+/// NOTE(review): vector/matrix members such as float4 do not set
+/// ContainsNumeric here, because isFloatingType() is false for them; confirm
+/// whether that is intended.
+static void traverseType(QualType T, RegisterBindingFlags &r) {
+  // Strip every array dimension so element types are classified directly.
+  const Type *ElemTy = T->getBaseElementTypeUnsafe();
+
+  if (ElemTy->isIntegralOrEnumerationType() || ElemTy->isFloatingType()) {
+    r.ContainsNumeric = true;
+    return;
+  }
+  const RecordType *RT = ElemTy->getAs<RecordType>();
+  if (!RT)
+    return;
+
+  RecordDecl *SubRD = RT->getDecl();
+
+  // The resource attribute lives on the primary template's pattern decl.
+  // A user-defined template has no such attribute; it must not be
+  // dereferenced blindly, and is handled like any other record below.
+  if (const auto *TDecl = dyn_cast<ClassTemplateSpecializationDecl>(SubRD)) {
+    const CXXRecordDecl *Pattern =
+        TDecl->getSpecializedTemplate()->getTemplatedDecl();
+    Pattern = Pattern->getCanonicalDecl();
+    if (const auto *ResAttr = Pattern->getAttr<HLSLResourceAttr>()) {
+      switch (ResAttr->getResourceClass()) {
+      case llvm::hlsl::ResourceClass::SRV:
+        r.SRV = true;
+        break;
+      case llvm::hlsl::ResourceClass::UAV:
+        r.UAV = true;
+        break;
+      case llvm::hlsl::ResourceClass::CBuffer:
+        r.CBV = true;
+        break;
+      case llvm::hlsl::ResourceClass::Sampler:
+        r.Sampler = true;
+        break;
+      }
+      return;
+    }
+  }
+
+  // Non-resource record: search its fields. For a template specialization,
+  // SubRD's own fields already have the template arguments substituted.
+  if (SubRD->isCompleteDefinition())
+    for (const FieldDecl *Field : SubRD->fields())
+      traverseType(Field->getType(), r);
+}
+
+/// Runs traverseType over every field of the record \p RD, switching on in
+/// \p r the flags for whatever numeric members and resource classes are
+/// found. A null or incomplete (forward-declared) record leaves \p r as is.
+void setResourceClassFlagsFromRecordDecl(RegisterBindingFlags &r,
+                                         const RecordDecl *RD) {
+  const bool HasFieldsToScan = RD && RD->isCompleteDefinition();
+  if (!HasFieldsToScan)
+    return;
+
+  for (const FieldDecl *FD : RD->fields())
+    traverseType(FD->getType(), r);
+}
+
+RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S, Decl *D) {
+  RegisterBindingFlags r;
+  if (!isDeclaredWithinCOrTBuffer(D)) {
+    // make sure the type is a basic / numeric type
+    if (VarDecl *v = dyn_cast<VarDecl>(D)) {
+      QualType t = v->getType();
+      // a numeric variable will inevitably end up in $Globals buffer
+      if (t->isIntegralType(S.getASTContext()) || t->isFloatingType())
+        r.DefaultGlobals = true;
+    }
+  }
+  // Cbuffers and Tbuffers are HLSLBufferDecl types
+  HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
+  // Samplers, UAVs, and SRVs are VarDecl types
+  VarDecl *VD = dyn_cast<VarDecl>(D);
+
+  assert(((VD && !CBufferOrTBuffer) || (!VD && CBufferOrTBuffer)) &&
+         "either VD or CBufferOrTBuffer should be set");
+
+  if (CBufferOrTBuffer) {
+    r.Resource = true;
+    if (CBufferOrTBuffer->isCBuffer())
+      r.CBV = true;
+    else
+      r.SRV = true;
+  } else if (VD) {
+    const HLSLResourceAttr *res_attr =
+        getHLSLResourceAttrFromEitherDecl(VD, CBufferOrTBuffer);
+    if (res_attr) {
+      llvm::hlsl::ResourceClass DeclResourceClass =
+          res_attr->getResourceClass();
+      r.Resource = true;
+      switch (DeclResourceClass) {
+      case llvm::hlsl::ResourceClass::SRV: {
----------------
damyanp wrote:

The braces around all the cases here aren't necessary.

https://github.com/llvm/llvm-project/pull/97103


More information about the cfe-commits mailing list