[clang] Implement resource binding type prefix mismatch diagnostic infrastructure (PR #97103)

Helena Kotas via cfe-commits cfe-commits at lists.llvm.org
Thu Aug 22 14:12:34 PDT 2024


================
@@ -459,7 +467,412 @@ void SemaHLSL::handleResourceClassAttr(Decl *D, const ParsedAttr &AL) {
   D->addAttr(HLSLResourceClassAttr::Create(getASTContext(), RC, ArgLoc));
 }
 
-void SemaHLSL::handleResourceBindingAttr(Decl *D, const ParsedAttr &AL) {
+/// Register-binding-related properties of a declaration, accumulated by the
+/// helpers that follow. All flags start out false.
+struct RegisterBindingFlags {
+  // Broad category of the declaration (e.g. Resource is set for cbuffers and
+  // tbuffers, Other for groupshared variables).
+  bool Resource = false;
+  bool UDT = false;
+  bool Other = false;
+  bool Basic = false;
+
+  // Resource classes found on the declaration, or inside it for a UDT.
+  bool SRV = false;
+  bool UAV = false;
+  bool CBV = false;
+  bool Sampler = false;
+
+  // Set when a (possibly nested) member has numeric type.
+  bool ContainsNumeric = false;
+  // Set for numeric globals declared outside any cbuffer/tbuffer, which end up
+  // in the $Globals buffer.
+  bool DefaultGlobals = false;
+};
+
+/// Returns true if \p TheDecl is declared directly inside a cbuffer or tbuffer,
+/// i.e. its immediate DeclContext is an HLSLBufferDecl.
+///
+/// Only the immediate parent context is inspected; enclosing contexts further
+/// up are not traversed. Returns false for a null declaration.
+static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) {
+  return TheDecl && isa<HLSLBufferDecl>(TheDecl->getDeclContext());
+}
+
+// Get the record decl from a var decl that we expect represents a resource.
+// The variable's type is first reduced to its pointee or (base) array element
+// type. Returns nullptr when that type is not a record type at all, e.g. a
+// numeric scalar, a typedef of one (HLSL `uint`), or a vector.
+static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
+  const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
+  assert(Ty && "Resource must have an element type.");
+
+  // isRecordType() looks at the canonical type, so sugared non-record types
+  // (typedefs to builtins, vector typedefs, ...) are rejected here rather than
+  // tripping the assertion below. A plain dyn_cast<BuiltinType> would miss
+  // those because it does not look through type sugar.
+  if (!Ty->isRecordType())
+    return nullptr;
+
+  CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
+  assert(TheRecordDecl && "Resource should have a resource type declaration.");
+  return TheRecordDecl;
+}
+
+/// Raise the single flag in \p Flags that corresponds to \p DeclResourceClass
+/// (CBuffer maps to the CBV flag). No other flag is touched.
+static void setResourceClassFlagsFromDeclResourceClass(
+    RegisterBindingFlags &Flags, llvm::hlsl::ResourceClass DeclResourceClass) {
+  using llvm::hlsl::ResourceClass;
+  switch (DeclResourceClass) {
+  case ResourceClass::SRV:
+    Flags.SRV = true;
+    return;
+  case ResourceClass::UAV:
+    Flags.UAV = true;
+    return;
+  case ResourceClass::CBuffer:
+    Flags.CBV = true;
+    return;
+  case ResourceClass::Sampler:
+    Flags.Sampler = true;
+    return;
+  }
+}
+
+/// Return the first attribute of type \p T found on a field of a record, or
+/// nullptr if no field carries one.
+///
+/// If \p VD is non-null, the record is taken from VD's type (see
+/// getRecordDeclFromVarDecl) and \p TheRecordDecl is ignored; nullptr is
+/// returned when VD's type has no record declaration. Otherwise \p
+/// TheRecordDecl must be non-null.
+///
+/// For a ClassTemplateSpecializationDecl, the fields of the canonical templated
+/// declaration of the specialized template are searched instead. This is
+/// necessary while resources like RWBuffer are defined externally as class
+/// templates whose handle member carries the attribute.
+template <typename T>
+static const T *
+getSpecifiedHLSLAttrFromVarDeclOrRecordDecl(VarDecl *VD,
+                                            RecordDecl *TheRecordDecl) {
+  if (VD) {
+    TheRecordDecl = getRecordDeclFromVarDecl(VD);
+    if (!TheRecordDecl)
+      return nullptr;
+  }
+
+  if (!TheRecordDecl)
+    llvm_unreachable("TheRecordDecl should not be null");
+
+  // Search the primary template's definition rather than the specialization.
+  if (const auto *TDecl =
+          dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
+    TheRecordDecl =
+        TDecl->getSpecializedTemplate()->getTemplatedDecl()->getCanonicalDecl();
+
+  for (const FieldDecl *FD : TheRecordDecl->fields())
+    if (const T *FoundAttr = FD->getAttr<T>())
+      return FoundAttr;
+  return nullptr;
+}
+
+/// Recursively accumulate into \p Flags what a member of type \p TheQualTy
+/// contributes. Arrays (including nested arrays) are classified like their
+/// element type:
+///  - integral, enumeration or floating types set ContainsNumeric;
+///  - a record carrying an HLSLResourceClassAttr, on one of its fields (e.g.
+///    a resource's handle) or on the record itself, sets the matching
+///    resource-class flag;
+///  - any other record is walked field by field;
+///  - anything else contributes nothing.
+/// NOTE(review): vector types (e.g. float4) are neither numeric nor records by
+/// these checks and so contribute nothing -- confirm this is intended.
+static void setFlagsFromType(QualType TheQualTy, RegisterBindingFlags &Flags) {
+  // Strip all array levels first so that e.g. `float a[4]` counts as numeric,
+  // matching how top-level global variables are classified.
+  const clang::Type *TheBaseType = TheQualTy->getBaseElementTypeUnsafe();
+
+  if (TheBaseType->isIntegralOrEnumerationType() ||
+      TheBaseType->isFloatingType()) {
+    Flags.ContainsNumeric = true;
+    return;
+  }
+
+  const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
+  if (!TheRecordTy)
+    return;
+
+  RecordDecl *SubRecordDecl = TheRecordTy->getDecl();
+
+  // Prefer an attribute on one of the record's members (the handle of a
+  // resource); fall back to one attached to the record declaration itself.
+  const HLSLResourceClassAttr *ResClassAttr =
+      getSpecifiedHLSLAttrFromVarDeclOrRecordDecl<HLSLResourceClassAttr>(
+          nullptr, SubRecordDecl);
+  if (!ResClassAttr)
+    ResClassAttr = SubRecordDecl->getAttr<HLSLResourceClassAttr>();
+
+  if (ResClassAttr) {
+    setResourceClassFlagsFromDeclResourceClass(Flags,
+                                               ResClassAttr->getResourceClass());
+    return;
+  }
+
+  // Not a resource: recurse into the fields of the user-defined type.
+  for (const FieldDecl *Field : SubRecordDecl->fields())
+    setFlagsFromType(Field->getType(), Flags);
+}
+
+/// Accumulate register binding flags from every field of \p RD via
+/// setFlagsFromType. A null record, or one that is not a complete definition,
+/// contributes nothing.
+static void setResourceClassFlagsFromRecordDecl(RegisterBindingFlags &Flags,
+                                                const RecordDecl *RD) {
+  if (!RD || !RD->isCompleteDefinition())
+    return;
+
+  for (const FieldDecl *FD : RD->fields())
+    setFlagsFromType(FD->getType(), Flags);
+}
+
+static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
+                                                         Decl *TheDecl) {
+
+  // Cbuffers and Tbuffers are HLSLBufferDecl types
+  HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl);
+  // Samplers, UAVs, and SRVs are VarDecl types
+  VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl);
+
+  assert(((TheVarDecl && !CBufferOrTBuffer) ||
+          (!TheVarDecl && CBufferOrTBuffer)) &&
+         "either TheVarDecl or CBufferOrTBuffer should be set");
+
+  RegisterBindingFlags Flags;
+
+  // check if the decl type is groupshared
+  if (TheDecl->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
+    Flags.Other = true;
+    return Flags;
+  }
+
+  if (!isDeclaredWithinCOrTBuffer(TheDecl)) {
+    // make sure the type is a basic / numeric type
+    if (TheVarDecl) {
+      QualType TheQualTy = TheVarDecl->getType();
+      // a numeric variable or an array of numeric variables
+      // will inevitably end up in $Globals buffer
+      const clang::Type *TheBaseType = TheQualTy.getTypePtr();
+      while (TheBaseType->isArrayType())
+        TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
+      if (TheBaseType->isIntegralType(S.getASTContext()) ||
+          TheBaseType->isFloatingType())
+        Flags.DefaultGlobals = true;
+    }
+  }
+
+  if (CBufferOrTBuffer) {
+    Flags.Resource = true;
+    if (CBufferOrTBuffer->isCBuffer())
+      Flags.CBV = true;
+    else
+      Flags.SRV = true;
+  } else if (TheVarDecl) {
+    const HLSLResourceClassAttr *resClassAttr =
+        getSpecifiedHLSLAttrFromVarDeclOrRecordDecl<HLSLResourceClassAttr>(
+            TheVarDecl, nullptr);
+    const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
+    while (TheBaseType->isArrayType())
+      TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
----------------
hekota wrote:

`TheBaseType` is not used until line `679: if (TheBaseType->isArithmeticType())`. Can you move this closer to where it is used?

https://github.com/llvm/llvm-project/pull/97103


More information about the cfe-commits mailing list