[clang] Implement resource binding type prefix mismatch diagnostic infrastructure (PR #97103)

Joshua Batista via cfe-commits cfe-commits at lists.llvm.org
Thu Aug 15 13:07:05 PDT 2024


================
@@ -459,7 +467,506 @@ void SemaHLSL::handleResourceClassAttr(Decl *D, const ParsedAttr &AL) {
   D->addAttr(HLSLResourceClassAttr::Create(getASTContext(), RC, ArgLoc));
 }
 
-void SemaHLSL::handleResourceBindingAttr(Decl *D, const ParsedAttr &AL) {
+/// Result of classifying a declaration that carries a register(...)
+/// annotation. The analysis is expected to set exactly one of the "kind"
+/// flags (Resource, UDT, Other, Basic); the remaining flags refine it.
+struct RegisterBindingFlags {
+  // Kind of the analyzed declaration.
+  bool Resource{false};
+  bool UDT{false};
+  bool Other{false};
+  bool Basic{false};
+
+  // Resource classes the declaration is, or (for a UDT) contains.
+  bool SRV{false};
+  bool UAV{false};
+  bool CBV{false};
+  bool Sampler{false};
+
+  // A user-defined struct has at least one numeric member.
+  bool ContainsNumeric{false};
+  // A numeric variable declared outside any cbuffer/tbuffer ($Globals).
+  bool DefaultGlobals{false};
+};
+
+/// Returns true if \p TheDecl is declared directly inside a cbuffer or
+/// tbuffer, i.e. its immediate DeclContext is an HLSLBufferDecl.
+///
+/// Only the immediate context is inspected; contexts further out are not
+/// searched. Returns false for a null declaration.
+static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) {
+  if (!TheDecl)
+    return false;
+  return isa<HLSLBufferDecl>(TheDecl->getDeclContext());
+}
+
+/// Returns the canonical C++ record declaration behind the type of \p VD, or
+/// nullptr if that type is not a class type.
+///
+/// A pointer or array type is first reduced to its pointee / element type.
+/// For a class template specialization, the canonical declaration of the
+/// template's pattern is returned instead of the specialization itself.
+static const CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
+  const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
+  assert(Ty && "Resource class must have an element type.");
+
+  // getAsCXXRecordDecl() looks through type sugar, so builtin types (also when
+  // spelled through a typedef) and other non-class types such as vectors,
+  // matrices and enums all yield nullptr here instead of tripping an assert.
+  const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
+  if (!TheRecordDecl)
+    return nullptr;
+
+  if (const auto *TDecl =
+          dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
+    TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl();
+  return TheRecordDecl->getCanonicalDecl();
+}
+
+/// Returns the first \p AttrT found on \p RD itself or, failing that, on one
+/// of its fields (most commonly the resource handle); nullptr if none.
+template <typename AttrT>
+static const AttrT *findAttrOnRecordOrFields(const CXXRecordDecl *RD) {
+  if (const auto *A = RD->getAttr<AttrT>())
+    return A;
+  for (const FieldDecl *FD : RD->fields())
+    if (const auto *A = FD->getAttr<AttrT>())
+      return A;
+  return nullptr;
+}
+
+/// Looks up \p AttrT for a register-annotated declaration. When \p VD is
+/// non-null the attribute is searched on the record declaration of its type
+/// (nullptr if the type has none); otherwise it is read directly from
+/// \p CBufferOrTBuffer. At least one of the two arguments must be non-null.
+template <typename AttrT>
+static const AttrT *
+getHLSLAttrFromEitherDecl(VarDecl *VD, HLSLBufferDecl *CBufferOrTBuffer) {
+  if (VD) {
+    const CXXRecordDecl *TheRecordDecl = getRecordDeclFromVarDecl(VD);
+    return TheRecordDecl ? findAttrOnRecordOrFields<AttrT>(TheRecordDecl)
+                         : nullptr;
+  }
+  if (CBufferOrTBuffer)
+    return CBufferOrTBuffer->getAttr<AttrT>();
+  llvm_unreachable("one of the two conditions should be true.");
+}
+
+/// Returns the resource class attribute (SRV/UAV/CBuffer/Sampler) of the
+/// record type of \p VD or of \p CBufferOrTBuffer, or nullptr if none.
+static const HLSLResourceClassAttr *
+getHLSLResourceClassAttrFromEitherDecl(VarDecl *VD,
+                                       HLSLBufferDecl *CBufferOrTBuffer) {
+  return getHLSLAttrFromEitherDecl<HLSLResourceClassAttr>(VD,
+                                                          CBufferOrTBuffer);
+}
+
+/// Returns the resource attribute of the record type of \p VD or of
+/// \p CBufferOrTBuffer, or nullptr if none.
+static const HLSLResourceAttr *
+getHLSLResourceAttrFromEitherDecl(VarDecl *VD,
+                                  HLSLBufferDecl *CBufferOrTBuffer) {
+  return getHLSLAttrFromEitherDecl<HLSLResourceAttr>(VD, CBufferOrTBuffer);
+}
+
+/// Sets the register-class flag in \p Flags that corresponds to \p RC.
+static void setFlagForResourceClass(RegisterBindingFlags &Flags,
+                                    llvm::hlsl::ResourceClass RC) {
+  switch (RC) {
+  case llvm::hlsl::ResourceClass::SRV:
+    Flags.SRV = true;
+    break;
+  case llvm::hlsl::ResourceClass::UAV:
+    Flags.UAV = true;
+    break;
+  case llvm::hlsl::ResourceClass::CBuffer:
+    Flags.CBV = true;
+    break;
+  case llvm::hlsl::ResourceClass::Sampler:
+    Flags.Sampler = true;
+    break;
+  }
+}
+
+/// Accumulates into \p Flags what a struct member of type \p TheQualTy
+/// contributes: numeric data (ContainsNumeric) and/or a resource of some class
+/// (SRV/UAV/CBV/Sampler). Arrays are looked through, and records that are not
+/// resources are traversed recursively, field by field.
+static void traverseType(QualType TheQualTy, RegisterBindingFlags &Flags) {
+  // Strip (possibly multi-dimensional) arrays first so that an array of
+  // numbers counts as numeric, just like a top-level numeric array does.
+  const clang::Type *TheBaseType = TheQualTy.getTypePtr();
+  while (TheBaseType->isArrayType())
+    TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
+
+  // NOTE(review): isFloatingType() is false for vector types such as float4,
+  // so vector members do not set ContainsNumeric; confirm that is intended.
+  if (TheBaseType->isIntegralOrEnumerationType() ||
+      TheBaseType->isFloatingType()) {
+    Flags.ContainsNumeric = true;
+    return;
+  }
+
+  // Anything else that is not a record contributes nothing.
+  const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
+  if (!TheRecordTy)
+    return;
+  RecordDecl *SubRecordDecl = TheRecordTy->getDecl();
+
+  // The resource class attribute lives on the record itself or on one of its
+  // fields (typically the handle). For a template specialization, consult the
+  // canonical declaration of the template's pattern instead.
+  const RecordDecl *AttrHolder = SubRecordDecl;
+  if (const auto *TDecl =
+          dyn_cast<ClassTemplateSpecializationDecl>(SubRecordDecl))
+    AttrHolder = TDecl->getSpecializedTemplate()
+                     ->getTemplatedDecl()
+                     ->getCanonicalDecl();
+
+  const HLSLResourceClassAttr *RCAttr =
+      AttrHolder->getAttr<HLSLResourceClassAttr>();
+  if (!RCAttr) {
+    for (const FieldDecl *FD : AttrHolder->fields()) {
+      RCAttr = FD->getAttr<HLSLResourceClassAttr>();
+      if (RCAttr)
+        break;
+    }
+  }
+
+  if (RCAttr) {
+    setFlagForResourceClass(Flags, RCAttr->getResourceClass());
+    return;
+  }
+
+  // Not a resource (this includes user-defined template structs, which have
+  // no resource class attribute): recurse into the members of the record
+  // actually used here, i.e. the instantiated specialization.
+  for (const FieldDecl *Field : SubRecordDecl->fields())
+    traverseType(Field->getType(), Flags);
+}
+
+/// Feeds the type of every field of \p RD into traverseType, accumulating the
+/// results in \p Flags. A null record, or one that is not a complete
+/// definition, leaves \p Flags untouched.
+void setResourceClassFlagsFromRecordDecl(RegisterBindingFlags &Flags,
+                                         const RecordDecl *RD) {
+  // Nothing to inspect without a full definition.
+  if (!RD || !RD->isCompleteDefinition())
+    return;
+
+  for (const FieldDecl *Member : RD->fields())
+    traverseType(Member->getType(), Flags);
+}
+
+/// Classifies the register-annotated declaration \p TheDecl, which must be
+/// either a VarDecl or an HLSLBufferDecl (cbuffer/tbuffer).
+///
+/// Exactly one of Resource, UDT, Basic or Other is set in the result:
+///  - groupshared declarations are Other;
+///  - cbuffers are Resource + CBV, tbuffers are Resource + SRV;
+///  - variables whose (array element) class type carries a resource class
+///    attribute are Resource with the matching SRV/UAV/CBV/Sampler flag;
+///  - other class types are UDT, with flags accumulated from their members;
+///  - arithmetic types are Basic; everything else is Other.
+/// DefaultGlobals is additionally set for integral/floating variables (or
+/// arrays of them) declared outside any cbuffer/tbuffer.
+static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S,
+                                                         Decl *TheDecl) {
+  // Cbuffers and Tbuffers are HLSLBufferDecl types
+  HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl);
+  // Samplers, UAVs, and SRVs are VarDecl types
+  VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl);
+
+  assert((TheVarDecl != nullptr) != (CBufferOrTBuffer != nullptr) &&
+         "either VD or CBufferOrTBuffer should be set");
+
+  RegisterBindingFlags Flags;
+
+  // check if the decl type is groupshared
+  if (TheDecl->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) {
+    Flags.Other = true;
+    return Flags;
+  }
+
+  if (CBufferOrTBuffer) {
+    Flags.Resource = true;
+    if (CBufferOrTBuffer->isCBuffer())
+      Flags.CBV = true;
+    else
+      Flags.SRV = true;
+    return Flags;
+  }
+
+  // Look through (possibly multi-dimensional) arrays to the element type.
+  const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr();
+  while (TheBaseType->isArrayType())
+    TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual();
+
+  // A numeric variable or an array of numeric variables declared outside any
+  // cbuffer/tbuffer will inevitably end up in the $Globals buffer.
+  if (!isDeclaredWithinCOrTBuffer(TheDecl) &&
+      (TheBaseType->isIntegralType(S.getASTContext()) ||
+       TheBaseType->isFloatingType()))
+    Flags.DefaultGlobals = true;
+
+  if (TheBaseType->isRecordType()) {
+    // Only class types can carry a resource class attribute, so the lookup is
+    // restricted to them; scalars, vectors, enums etc. never reach it.
+    if (const HLSLResourceClassAttr *RCAttr =
+            getHLSLResourceClassAttrFromEitherDecl(TheVarDecl, nullptr)) {
+      Flags.Resource = true;
+      switch (RCAttr->getResourceClass()) {
+      case llvm::hlsl::ResourceClass::SRV:
+        Flags.SRV = true;
+        break;
+      case llvm::hlsl::ResourceClass::UAV:
+        Flags.UAV = true;
+        break;
+      case llvm::hlsl::ResourceClass::CBuffer:
+        Flags.CBV = true;
+        break;
+      case llvm::hlsl::ResourceClass::Sampler:
+        Flags.Sampler = true;
+        break;
+      }
+      return Flags;
+    }
+
+    // A user-defined struct: recurse through its members and set the
+    // appropriate resource class / numeric flags.
+    Flags.UDT = true;
+    const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>();
+    assert(TheRecordTy && "The Qual Type should be Record Type");
+    setResourceClassFlagsFromRecordDecl(Flags, TheRecordTy->getDecl());
+    return Flags;
+  }
+
+  // NOTE(review): vector and matrix types (e.g. float4) are not arithmetic
+  // types and therefore end up as Other here; confirm that is intended.
+  if (TheBaseType->isArithmeticType())
+    Flags.Basic = true;
+  else
+    Flags.Other = true;
+  return Flags;
+}
+
+/// Register kinds, selected by the first letter of a register(...) slot in
+/// getRegisterTypeIndex: t = SRV, u = UAV, b = CBuffer, s = Sampler, c = C,
+/// i = I.
+enum RegisterType {
+  SRV = 0,
+  UAV = 1,
+  CBuffer = 2,
+  Sampler = 3,
+  C = 4,
+  I = 5
+};
+
+/// Maps the first letter of a register slot (e.g. "t0", "U1") to its
+/// RegisterType, case-insensitively. \p Slot must be non-empty and start with
+/// one of t, u, b, s, c or i (in either case).
+static int getRegisterTypeIndex(StringRef Slot) {
+  assert(!Slot.empty() && "register slot must not be empty");
+  switch (Slot[0]) {
+  case 't':
+  case 'T':
+    return RegisterType::SRV;
+  case 'u':
+  case 'U':
+    return RegisterType::UAV;
+  case 'b':
+  case 'B':
+    return RegisterType::CBuffer;
+  case 's':
+  case 'S':
+    return RegisterType::Sampler;
+  case 'c':
+  case 'C':
+    return RegisterType::C;
+  case 'i':
+  case 'I':
+    return RegisterType::I;
+  default:
+    llvm_unreachable("invalid register type");
+  }
+}
+
+/// Diagnoses register(...) annotations on \p TheDecl that repeat a register
+/// type. The register type of \p Slot is recorded first; then every
+/// HLSLResourceBindingAttr currently attached to \p TheDecl is checked, and
+/// err_hlsl_duplicate_register_annotation is emitted for each one whose
+/// register type has already been seen.
+static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl,
+                                                StringRef Slot) {
+  // Make sure that no two register annotations applied to the decl use the
+  // same register type. One entry per RegisterType enumerator.
+  bool RegisterTypesDetected[RegisterType::I + 1] = {false};
+  RegisterTypesDetected[getRegisterTypeIndex(Slot)] = true;
+
+  for (const auto *BindingAttr :
+       TheDecl->specific_attrs<HLSLResourceBindingAttr>()) {
+    int RegisterTypeIndex = getRegisterTypeIndex(BindingAttr->getSlot());
+    if (RegisterTypesDetected[RegisterTypeIndex])
+      S.Diag(TheDecl->getLocation(),
+             diag::err_hlsl_duplicate_register_annotation)
+          << RegisterTypeIndex;
+    else
+      RegisterTypesDetected[RegisterTypeIndex] = true;
+  }
+}
+
+/// Returns a printable name for the type of a register-annotated declaration:
+/// the printed type of a VarDecl, or the keyword "cbuffer"/"tbuffer" for an
+/// HLSLBufferDecl.
+std::string getHLSLResourceTypeStr(Sema &S, Decl *TheDecl) {
+  if (const auto *Var = dyn_cast<VarDecl>(TheDecl)) {
+    const PrintingPolicy Policy = S.getPrintingPolicy();
+    return QualType::getAsString(Var->getType().split(), Policy);
+  }
+
+  // Not a variable, so this is expected to be a cbuffer/tbuffer declaration.
+  auto *Buffer = dyn_cast<HLSLBufferDecl>(TheDecl);
+  return Buffer->isCBuffer() ? "cbuffer" : "tbuffer";
+}
+
+static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc,
+                                          Decl *TheDecl, StringRef &Slot) {
+
+  // Samplers, UAVs, and SRVs are VarDecl types
+  VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl);
+  // Cbuffers and Tbuffers are HLSLBufferDecl types
+  HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl);
+
+  // exactly one of these two types should be set
+  assert(((TheVarDecl && !CBufferOrTBuffer) ||
+          (!TheVarDecl && CBufferOrTBuffer)) &&
+         "either TheVarDecl or CBufferOrTBuffer should be set");
+
+  RegisterBindingFlags Flags = HLSLFillRegisterBindingFlags(S, TheDecl);
+  assert((int)Flags.Other + (int)Flags.Resource + (int)Flags.Basic +
+                 (int)Flags.UDT ==
+             1 &&
+         "only one resource analysis result should be expected");
+
+  int regType = getRegisterTypeIndex(Slot);
+
+  // first, if "other" is set, emit an error
+  if (Flags.Other) {
+    if (regType == RegisterType::I) {
+      S.Diag(TheDecl->getLocation(),
+             diag::warn_hlsl_deprecated_register_type_i);
+      return;
+    }
+    S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << regType;
+    return;
+  }
+
+  // next, if multiple register annotations exist, check that none conflict.
+  ValidateMultipleRegisterAnnotations(S, TheDecl, Slot);
+
+  // next, if resource is set, make sure the register type in the register
+  // annotation is compatible with the variable's resource type.
+  if (Flags.Resource) {
+    if (regType == RegisterType::I) {
+      S.Diag(TheDecl->getLocation(),
+             diag::warn_hlsl_deprecated_register_type_i);
+      return;
+    }
+    const HLSLResourceAttr *resAttr =
+        getHLSLResourceAttrFromEitherDecl(TheVarDecl, CBufferOrTBuffer);
+    const HLSLResourceClassAttr *resClassAttr =
+        getHLSLResourceClassAttrFromEitherDecl(TheVarDecl, CBufferOrTBuffer);
+    assert(resAttr && resClassAttr &&
+           "any decl that set the resource flag on analysis should "
+           "have a resource attribute and resource class attribute attached.");
+    const llvm::hlsl::ResourceClass DeclResourceClass =
+        resClassAttr->getResourceClass();
+
+    switch (DeclResourceClass) {
+    case llvm::hlsl::ResourceClass::SRV:
+      if (regType != RegisterType::SRV)
+        S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch)
+            << regType;
+      break;
+    case llvm::hlsl::ResourceClass::UAV:
+      if (regType != RegisterType::UAV)
+        S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch)
+            << regType;
+      break;
+    case llvm::hlsl::ResourceClass::CBuffer:
+      if (regType != RegisterType::CBuffer)
+        S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch)
+            << regType;
+      break;
+    case llvm::hlsl::ResourceClass::Sampler:
+      if (regType != RegisterType::Sampler)
+        S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch)
+            << regType;
+      break;
+    }
+    return;
+  }
+
+  // next, handle diagnostics for when the "basic" flag is set,
+  // including the legacy "i" and "b" register types.
+  if (Flags.Basic) {
+    if (Flags.DefaultGlobals) {
+      if (regType == RegisterType::CBuffer)
+        S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b);
+      else if (regType == RegisterType::I)
+        S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i);
----------------
bob80905 wrote:

Because `b` is used for CBuffers, I can't immediately emit the diagnostic and return when it is detected, but I think I can do that for `i`, so I've made that change.

https://github.com/llvm/llvm-project/pull/97103


More information about the cfe-commits mailing list