[clang] Implement resource binding type prefix mismatch diagnostic infrastructure (PR #97103)

Joshua Batista via cfe-commits cfe-commits at lists.llvm.org
Thu Aug 15 11:29:13 PDT 2024


================
@@ -459,7 +467,506 @@ void SemaHLSL::handleResourceClassAttr(Decl *D, const ParsedAttr &AL) {
   D->addAttr(HLSLResourceClassAttr::Create(getASTContext(), RC, ArgLoc));
 }
 
-void SemaHLSL::handleResourceBindingAttr(Decl *D, const ParsedAttr &AL) {
+// Result of analysing a declaration that carries a `register` annotation.
+// Filled in by HLSLFillRegisterBindingFlags; every flag defaults to false.
+struct RegisterBindingFlags {
+  // Kind of the annotated declaration. DiagnoseHLSLRegisterAttribute asserts
+  // that exactly one of these four is set.
+  bool Resource = false;
+  bool UDT = false;
+  bool Other = false;
+  bool Basic = false;
+
+  // Resource classes found on the declaration itself or, for a UDT, on the
+  // members reached while traversing it.
+  bool SRV = false;
+  bool UAV = false;
+  bool CBV = false;
+  bool Sampler = false;
+
+  // Set by traverseType when a member of integral, enumeration or floating
+  // type is encountered.
+  bool ContainsNumeric = false;
+  // Set for a variable of integral or floating type (or an array of such)
+  // that is not declared directly inside a cbuffer/tbuffer.
+  bool DefaultGlobals = false;
+};
+
+// Returns true when the immediate declaration context of TheDecl is an
+// HLSLBufferDecl (a cbuffer or tbuffer). Only the direct parent context is
+// inspected, not the whole chain of enclosing contexts. A null TheDecl
+// yields false.
+bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) {
+  return TheDecl && isa<HLSLBufferDecl>(TheDecl->getDeclContext());
+}
+
+// Returns the canonical record declaration behind the type of VD, looking
+// through pointers and one level of array. For a class template
+// specialization the templated declaration of the primary template is
+// returned instead of the specialization.
+//
+// Returns nullptr when the (element) type is not a C++ class type. This
+// function is called for every annotated variable, not only for resources,
+// so builtin, vector and enum types must be tolerated here rather than
+// asserted against.
+const CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
+  const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
+  assert(Ty && "Resource class must have an element type.");
+
+  // getAsCXXRecordDecl looks through sugar and yields null for every
+  // non-class type, which also covers the builtin types.
+  const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
+  if (!TheRecordDecl)
+    return nullptr;
+
+  if (const auto *TDecl =
+          dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
+    TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl();
+  return TheRecordDecl->getCanonicalDecl();
+}
+
+// Looks up the HLSLResourceClassAttr for either a variable or a
+// cbuffer/tbuffer declaration; VD takes precedence when both are given.
+//
+// For a variable, the attribute is searched on the record declaration of the
+// variable's type and then on that record's fields (the resource handle,
+// most commonly). Returns nullptr when the variable has no record type or no
+// attribute is found. For a buffer, the attribute is read from the buffer
+// declaration itself. Passing two null pointers is a programming error.
+const HLSLResourceClassAttr *
+getHLSLResourceClassAttrFromEitherDecl(VarDecl *VD,
+                                       HLSLBufferDecl *CBufferOrTBuffer) {
+  if (!VD) {
+    if (CBufferOrTBuffer)
+      return CBufferOrTBuffer->getAttr<HLSLResourceClassAttr>();
+    llvm_unreachable("one of the two conditions should be true.");
+  }
+
+  const CXXRecordDecl *ResRecord = getRecordDeclFromVarDecl(VD);
+  if (!ResRecord)
+    return nullptr;
+
+  if (const auto *OnRecord = ResRecord->getAttr<HLSLResourceClassAttr>())
+    return OnRecord;
+
+  for (const auto *Member : ResRecord->fields())
+    if (const auto *OnMember = Member->getAttr<HLSLResourceClassAttr>())
+      return OnMember;
+
+  return nullptr;
+}
+
+// Looks up the HLSLResourceAttr for either a variable or a cbuffer/tbuffer
+// declaration; VD takes precedence when both are given.
+//
+// For a variable, the attribute is searched on the record declaration of the
+// variable's type and then on that record's fields (the resource handle,
+// most commonly). Returns nullptr when the variable has no record type or no
+// attribute is found. For a buffer, the attribute is read from the buffer
+// declaration itself. Passing two null pointers is a programming error.
+const HLSLResourceAttr *
+getHLSLResourceAttrFromEitherDecl(VarDecl *VD,
+                                  HLSLBufferDecl *CBufferOrTBuffer) {
+  if (!VD) {
+    if (CBufferOrTBuffer)
+      return CBufferOrTBuffer->getAttr<HLSLResourceAttr>();
+    llvm_unreachable("one of the two conditions should be true.");
+  }
+
+  const CXXRecordDecl *ResRecord = getRecordDeclFromVarDecl(VD);
+  if (!ResRecord)
+    return nullptr;
+
+  if (const auto *OnRecord = ResRecord->getAttr<HLSLResourceAttr>())
+    return OnRecord;
+
+  for (const auto *Member : ResRecord->fields())
+    if (const auto *OnMember = Member->getAttr<HLSLResourceAttr>())
+      return OnMember;
+
+  return nullptr;
+}
+
+// Sets the flag in Flags that corresponds to the resource class RC.
+static void setRegisterBindingFlagForClass(RegisterBindingFlags &Flags,
+                                           llvm::hlsl::ResourceClass RC) {
+  switch (RC) {
+  case llvm::hlsl::ResourceClass::SRV:
+    Flags.SRV = true;
+    break;
+  case llvm::hlsl::ResourceClass::UAV:
+    Flags.UAV = true;
+    break;
+  case llvm::hlsl::ResourceClass::CBuffer:
+    Flags.CBV = true;
+    break;
+  case llvm::hlsl::ResourceClass::Sampler:
+    Flags.Sampler = true;
+    break;
+  }
+}
+
+// Classifies one member type of a user-defined type and records the result
+// in Flags:
+//  - an integral, enumeration or floating type sets ContainsNumeric;
+//  - a record (or array of records) carrying an HLSLResourceClassAttr sets
+//    the matching SRV/UAV/CBV/Sampler flag;
+//  - any other record is traversed recursively, field by field;
+//  - every other type is ignored.
+//
+// NOTE(review): the numeric check runs before array types are stripped, so
+// an array of numeric elements does not set ContainsNumeric - confirm that
+// this is intended.
+void traverseType(QualType TheQualTy, RegisterBindingFlags &Flags) {
+  if (TheQualTy->isIntegralOrEnumerationType() || TheQualTy->isFloatingType()) {
+    Flags.ContainsNumeric = true;
+    return;
+  }
+
+  // Look through any number of array dimensions to the element type.
+  const clang::Type *TheBaseType = TheQualTy.getTypePtr();
+  while (TheBaseType->isArrayType())
+    TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
+
+  const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
+  if (!TheRecordTy)
+    return;
+
+  RecordDecl *SubRecordDecl = TheRecordTy->getDecl();
+  const HLSLResourceClassAttr *Attr = nullptr;
+  if (const auto *TDecl =
+          dyn_cast<ClassTemplateSpecializationDecl>(SubRecordDecl)) {
+    // For a template specialization the attribute is looked up on the
+    // templated declaration of the primary template, or failing that on one
+    // of its fields.
+    const auto *TheRecordDecl =
+        TDecl->getSpecializedTemplate()->getTemplatedDecl()->getCanonicalDecl();
+    Attr = TheRecordDecl->getAttr<HLSLResourceClassAttr>();
+    if (!Attr) {
+      for (const auto *FD : TheRecordDecl->fields()) {
+        Attr = FD->getAttr<HLSLResourceClassAttr>();
+        if (Attr)
+          break;
+      }
+    }
+  } else {
+    Attr = SubRecordDecl->getAttr<HLSLResourceClassAttr>();
+  }
+
+  if (Attr) {
+    setRegisterBindingFlagForClass(Flags, Attr->getResourceClass());
+    return;
+  }
+
+  // Not a resource. This includes template specializations without any
+  // resource class attribute, which previously caused a null dereference.
+  // Recurse into the members instead.
+  for (const auto *Field : SubRecordDecl->fields())
+    traverseType(Field->getType(), Flags);
+}
+
+// Runs traverseType over the type of every field of RD, accumulating the
+// results in Flags. Does nothing when RD is null or is not a complete
+// definition.
+void setResourceClassFlagsFromRecordDecl(RegisterBindingFlags &Flags,
+                                         const RecordDecl *RD) {
+  if (!RD || !RD->isCompleteDefinition())
+    return;
+
+  for (const auto *Member : RD->fields())
+    traverseType(Member->getType(), Flags);
+}
+
+// Analyses the declaration that a `register` annotation is attached to and
+// returns its classification.
+//
+// TheDecl must be either a VarDecl (samplers, UAVs, SRVs and ordinary
+// variables) or an HLSLBufferDecl (cbuffers and tbuffers).
+//  - A groupshared declaration is reported as Other and nothing else.
+//  - A cbuffer is reported as Resource + CBV, a tbuffer as Resource + SRV.
+//  - A variable whose type carries a resource class attribute is reported
+//    as Resource plus the flag of that class.
+//  - Any other variable is Basic (arithmetic element type), UDT (record
+//    element type, whose members are traversed for further flags) or Other.
+// DefaultGlobals is additionally set for a variable of integral or floating
+// element type that is not declared directly inside a cbuffer/tbuffer.
+RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S, Decl *TheDecl) {
+  VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl);
+  HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl);
+
+  assert(((TheVarDecl && !CBufferOrTBuffer) ||
+          (!TheVarDecl && CBufferOrTBuffer)) &&
+         "either VD or CBufferOrTBuffer should be set");
+
+  RegisterBindingFlags Result;
+
+  if (TheDecl->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
+    Result.Other = true;
+    return Result;
+  }
+
+  if (CBufferOrTBuffer) {
+    Result.Resource = true;
+    if (CBufferOrTBuffer->isCBuffer())
+      Result.CBV = true;
+    else
+      Result.SRV = true;
+    return Result;
+  }
+
+  if (!TheVarDecl)
+    return Result;
+
+  // Element type of the variable with all array dimensions removed.
+  const clang::Type *ElemTy = TheVarDecl->getType().getTypePtr();
+  while (ElemTy->isArrayType())
+    ElemTy = ElemTy->getArrayElementTypeNoTypeQual();
+
+  // A numeric variable or an array of numeric variables declared outside a
+  // cbuffer/tbuffer will inevitably end up in the $Globals buffer.
+  const bool IsNumeric =
+      ElemTy->isIntegralType(S.getASTContext()) || ElemTy->isFloatingType();
+  if (IsNumeric && !isDeclaredWithinCOrTBuffer(TheDecl))
+    Result.DefaultGlobals = true;
+
+  const HLSLResourceClassAttr *ClassAttr =
+      getHLSLResourceClassAttrFromEitherDecl(TheVarDecl, CBufferOrTBuffer);
+  if (!ClassAttr) {
+    if (ElemTy->isArithmeticType()) {
+      Result.Basic = true;
+    } else if (ElemTy->isRecordType()) {
+      Result.UDT = true;
+      const RecordType *TheRecordTy = ElemTy->getAs<RecordType>();
+      assert(TheRecordTy && "The Qual Type should be Record Type");
+      // Walk the members and set the resource class flags they imply.
+      setResourceClassFlagsFromRecordDecl(Result, TheRecordTy->getDecl());
+    } else {
+      Result.Other = true;
+    }
+    return Result;
+  }
+
+  Result.Resource = true;
+  switch (ClassAttr->getResourceClass()) {
+  case llvm::hlsl::ResourceClass::SRV:
+    Result.SRV = true;
+    break;
+  case llvm::hlsl::ResourceClass::UAV:
+    Result.UAV = true;
+    break;
+  case llvm::hlsl::ResourceClass::CBuffer:
+    Result.CBV = true;
+    break;
+  case llvm::hlsl::ResourceClass::Sampler:
+    Result.Sampler = true;
+    break;
+  }
+  return Result;
+}
+
+// Register types that can appear in a `register` annotation. The enumerator
+// values are the indices returned by getRegisterTypeIndex, which are also
+// used to index the six-element RegisterTypesDetected array.
+enum RegisterType {
+  SRV,     // 't' / 'T'
+  UAV,     // 'u' / 'U'
+  CBuffer, // 'b' / 'B'
+  Sampler, // 's' / 'S'
+  C,       // 'c' / 'C'
+  I        // 'i' / 'I'
+};
+
+// Maps the first character of a register slot string (for example "t0" or
+// "U3") to the matching RegisterType value. Both lower-case and upper-case
+// prefixes are accepted. Slot must be non-empty and must start with one of
+// t, u, b, s, c or i; anything else is a programming error.
+int getRegisterTypeIndex(StringRef Slot) {
+  assert(!Slot.empty() && "register slot must not be empty");
+  switch (Slot[0]) {
+  case 't':
+  case 'T':
+    return RegisterType::SRV;
+  case 'u':
+  case 'U':
+    return RegisterType::UAV;
+  case 'b':
+  // This label was the multi-character constant 'B ', which can never equal
+  // a single character, so an upper-case 'B' prefix fell through to
+  // llvm_unreachable.
+  case 'B':
+    return RegisterType::CBuffer;
+  case 's':
+  case 'S':
+    return RegisterType::Sampler;
+  case 'c':
+  case 'C':
+    return RegisterType::C;
+  case 'i':
+  case 'I':
+    return RegisterType::I;
+  default:
+    llvm_unreachable("invalid register type");
+  }
+}
+
+// Checks that no two register annotations on TheDecl use the same register
+// type. Slot is the slot of the annotation currently being processed; it is
+// compared against every HLSLResourceBindingAttr already attached to
+// TheDecl, and err_hlsl_duplicate_register_annotation is emitted at the
+// declaration's location for each repeated register type.
+static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
+                                                StringRef &Slot) {
+  bool SeenRegisterType[6] = {false};
+  SeenRegisterType[getRegisterTypeIndex(Slot)] = true;
+
+  for (const auto *Binding :
+       TheDecl->specific_attrs<HLSLResourceBindingAttr>()) {
+    const int TypeIdx = getRegisterTypeIndex(Binding->getSlot());
+    if (!SeenRegisterType[TypeIdx]) {
+      SeenRegisterType[TypeIdx] = true;
+      continue;
+    }
+    S.Diag(TheDecl->getLocation(),
+           diag::err_hlsl_duplicate_register_annotation)
+        << TypeIdx;
+  }
+}
+
+// Returns a printable name for the declaration's resource type: the
+// variable's type printed with the Sema printing policy for a VarDecl,
+// otherwise "cbuffer" or "tbuffer".
+// NOTE(review): a declaration that is neither a VarDecl nor an
+// HLSLBufferDecl dereferences a null pointer here; callers are presumably
+// expected to rule that out - confirm.
+std::string getHLSLResourceTypeStr(Sema &S, Decl *TheDecl) {
+  if (const auto *Var = dyn_cast<VarDecl>(TheDecl)) {
+    const PrintingPolicy Policy = S.getPrintingPolicy();
+    return QualType::getAsString(Var->getType().split(), Policy);
+  }
+
+  const auto *Buffer = dyn_cast<HLSLBufferDecl>(TheDecl);
+  return Buffer->isCBuffer() ? "cbuffer" : "tbuffer";
+}
+
+static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
+                                          Decl *TheDecl, StringRef &Slot) {
+
+  // Samplers, UAVs, and SRVs are VarDecl types
+  VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl);
+  // Cbuffers and Tbuffers are HLSLBufferDecl types
+  HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl);
+
+  // exactly one of these two types should be set
+  assert(((TheVarDecl && !CBufferOrTBuffer) ||
+          (!TheVarDecl && CBufferOrTBuffer)) &&
+         "either TheVarDecl or CBufferOrTBuffer should be set");
+
+  RegisterBindingFlags Flags = HLSLFillRegisterBindingFlags(S, TheDecl);
+  assert((int)Flags.Other + (int)Flags.Resource + (int)Flags.Basic +
+                 (int)Flags.UDT ==
+             1 &&
+         "only one resource analysis result should be expected");
+
+  int regType = getRegisterTypeIndex(Slot);
+
+  // first, if "other" is set, emit an error
+  if (Flags.Other) {
+    if (regType == RegisterType::I) {
+      S.Diag(TheDecl->getLocation(),
+             diag::warn_hlsl_deprecated_register_type_i);
+      return;
+    }
+    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << regType;
+    return;
+  }
+
+  // next, if multiple register annotations exist, check that none conflict.
+  ValidateMultipleRegisterAnnotations(S, TheDecl, Slot);
+
+  // next, if resource is set, make sure the register type in the register
+  // annotation is compatible with the variable's resource type.
+  if (Flags.Resource) {
+    if (regType == RegisterType::I) {
+      S.Diag(TheDecl->getLocation(),
+             diag::warn_hlsl_deprecated_register_type_i);
+      return;
+    }
+    const HLSLResourceAttr *resAttr =
+        getHLSLResourceAttrFromEitherDecl(TheVarDecl, CBufferOrTBuffer);
----------------
bob80905 wrote:

The assert doesn't fire because it is in the `Resource` case, and none of the cases in the UDT test file are `Resource`s; they're UDTs, so the `UDT` block is the relevant one.
As for your first statement: yes, I expect every resource type to have the resource attribute. Despite the fact that we will soon be removing this, at this point in time every resource should have the resource attribute, so we should verify that.

https://github.com/llvm/llvm-project/pull/97103


More information about the cfe-commits mailing list