[clang] Implement resource binding type prefix mismatch flag setting logic (PR #97103)

Joshua Batista via cfe-commits cfe-commits at lists.llvm.org
Mon Jul 8 18:02:52 PDT 2024


================
@@ -437,7 +460,406 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
     D->addAttr(NewAttr);
 }
 
+/// Classification of a declaration that carries a register annotation, as
+/// filled in by HLSLFillRegisterBindingFlags.
+struct register_binding_flags {
+  // Broad category of the declaration. The diagnostic code asserts that
+  // exactly one of these four is set.
+  //   resource: cbuffer/tbuffer, or a variable whose type has a resource attr
+  //   udt:      variable of aggregate (user-defined) type
+  //   other:    variable that is neither builtin nor aggregate
+  //   basic:    variable of builtin type
+  bool resource{false};
+  bool udt{false};
+  bool other{false};
+  bool basic{false};
+
+  // Resource classes seen on the declaration itself or, for user-defined
+  // types, on fields reached while walking the record.
+  bool srv{false};
+  bool uav{false};
+  bool cbv{false};
+  bool sampler{false};
+
+  // Set when a walked field has integral, enumeration or floating type.
+  bool contains_numeric{false};
+  // Set for an integral or floating variable declared outside any
+  // cbuffer/tbuffer (it ends up in the $Globals buffer).
+  bool default_globals{false};
+};
+
+/// Returns true when \p decl is nested, at any depth, inside an
+/// HLSLBufferDecl (a cbuffer or tbuffer). A null \p decl yields false.
+bool isDeclaredWithinCOrTBuffer(const Decl *decl) {
+  if (decl == nullptr)
+    return false;
+
+  // Walk outwards through the enclosing declaration contexts until either a
+  // buffer declaration is found or the chain of parents runs out.
+  for (const DeclContext *Ctx = decl->getDeclContext(); Ctx != nullptr;
+       Ctx = Ctx->getParent())
+    if (isa<HLSLBufferDecl>(Ctx))
+      return true;
+
+  return false;
+}
+
+/// Finds the canonical record declaration behind the type of
+/// \p SamplerUAVOrSRV, looking through pointers and arrays. For a class
+/// template specialization the templated declaration of the primary template
+/// is used instead of the specialization itself.
+///
+/// \returns nullptr when the element type is a builtin type.
+const CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *SamplerUAVOrSRV) {
+  const Type *ElemTy =
+      SamplerUAVOrSRV->getType()->getPointeeOrArrayElementType();
+  if (ElemTy == nullptr)
+    llvm_unreachable("Resource class must have an element type.");
+
+  // Builtin element types have no record declaration to report.
+  if (isa<BuiltinType>(ElemTy))
+    return nullptr;
+
+  const CXXRecordDecl *Record = ElemTy->getAsCXXRecordDecl();
+  if (Record == nullptr)
+    llvm_unreachable("Resource class should have a resource type declaration.");
+
+  // Step from a specialization back to the declaration it was stamped from.
+  if (const auto *Spec = dyn_cast<ClassTemplateSpecializationDecl>(Record))
+    Record = Spec->getSpecializedTemplate()->getTemplatedDecl();
+  return Record->getCanonicalDecl();
+}
+
+/// Looks up the HLSLResourceAttr for whichever of the two declarations is
+/// non-null; \p SamplerUAVOrSRV takes precedence when both are given.
+///
+/// For a variable the attribute is read from the record declaration of its
+/// type; nullptr is returned when that type has no record declaration. For a
+/// buffer the attribute is read from the buffer declaration itself. Passing
+/// two null pointers is a programming error.
+const HLSLResourceAttr *
+getHLSLResourceAttrFromEitherDecl(VarDecl *SamplerUAVOrSRV,
+                                  HLSLBufferDecl *CBufferOrTBuffer) {
+  if (SamplerUAVOrSRV) {
+    if (const CXXRecordDecl *Record =
+            getRecordDeclFromVarDecl(SamplerUAVOrSRV))
+      return Record->getAttr<HLSLResourceAttr>();
+    return nullptr;
+  }
+
+  if (CBufferOrTBuffer)
+    return CBufferOrTBuffer->getAttr<HLSLResourceAttr>();
+
+  llvm_unreachable("one of the two conditions should be true.");
+}
+
+/// Recursively inspects \p T and records what it contains in \p r.
+///
+///  - Integral, enumeration and floating types set `contains_numeric`.
+///  - A class template specialization whose primary template carries an
+///    HLSLResourceAttr sets the flag matching its resource class.
+///  - Any other record with a complete definition has each of its fields
+///    inspected in turn.
+///
+/// Types that match none of the above leave \p r untouched.
+void traverseType(QualType T, register_binding_flags &r) {
+  if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
+    r.contains_numeric = true;
+    return;
+  }
+
+  const RecordType *RT = T->getAs<RecordType>();
+  if (!RT)
+    return;
+
+  RecordDecl *SubRD = RT->getDecl();
+  if (auto *TDecl = dyn_cast<ClassTemplateSpecializationDecl>(SubRD)) {
+    const CXXRecordDecl *TheRecordDecl =
+        TDecl->getSpecializedTemplate()->getTemplatedDecl()->getCanonicalDecl();
+    // Not every template specialization is a resource type: getAttr returns
+    // null when the attribute is absent, so it must be checked before use.
+    if (const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>()) {
+      switch (Attr->getResourceClass()) {
+      case llvm::hlsl::ResourceClass::SRV:
+        r.srv = true;
+        break;
+      case llvm::hlsl::ResourceClass::UAV:
+        r.uav = true;
+        break;
+      case llvm::hlsl::ResourceClass::CBuffer:
+        r.cbv = true;
+        break;
+      case llvm::hlsl::ResourceClass::Sampler:
+        r.sampler = true;
+        break;
+      }
+      return;
+    }
+    // A specialization without the attribute is treated like any other
+    // record below, so resources nested in its fields are still found.
+  }
+
+  if (SubRD->isCompleteDefinition()) {
+    for (const auto *Field : SubRD->fields())
+      traverseType(Field->getType(), r);
+  }
+}
+
+/// Feeds the type of every field of \p RD through traverseType so that \p r
+/// reflects what the record contains. Does nothing when \p RD is null or is
+/// not a complete definition.
+void setResourceClassFlagsFromRecordDecl(register_binding_flags &r,
+                                         const RecordDecl *RD) {
+  if (RD == nullptr || !RD->isCompleteDefinition())
+    return;
+
+  for (const auto *Member : RD->fields())
+    traverseType(Member->getType(), r);
+}
+
+/// Classifies \p D for register-binding diagnostics.
+///
+/// \p D must be either an HLSLBufferDecl (cbuffer/tbuffer) or a VarDecl; any
+/// other kind of declaration is a programming error. The returned flags have
+/// exactly one of `resource`, `basic`, `udt` or `other` set, together with
+/// whichever resource-class and numeric flags apply.
+register_binding_flags HLSLFillRegisterBindingFlags(Sema &S, Decl *D) {
+  register_binding_flags Flags;
+
+  // A numeric variable declared outside any cbuffer/tbuffer will inevitably
+  // end up in the $Globals buffer.
+  if (!isDeclaredWithinCOrTBuffer(D)) {
+    if (auto *GlobalVar = dyn_cast<VarDecl>(D)) {
+      QualType GlobalTy = GlobalVar->getType();
+      Flags.default_globals = GlobalTy->isIntegralType(S.getASTContext()) ||
+                              GlobalTy->isFloatingType();
+    }
+  }
+
+  // cbuffers and tbuffers are HLSLBufferDecls: always a resource, bound as a
+  // constant buffer view or a shader resource view respectively.
+  if (auto *Buffer = dyn_cast<HLSLBufferDecl>(D)) {
+    Flags.resource = true;
+    if (Buffer->isCBuffer())
+      Flags.cbv = true;
+    else
+      Flags.srv = true;
+    return Flags;
+  }
+
+  // Samplers, UAVs and SRVs (and plain variables) are VarDecls.
+  auto *Var = dyn_cast<VarDecl>(D);
+  if (!Var)
+    llvm_unreachable("unknown decl type");
+
+  // A variable whose type carries a resource attribute is a resource of the
+  // class named by that attribute.
+  if (const HLSLResourceAttr *ResAttr =
+          getHLSLResourceAttrFromEitherDecl(Var, /*CBufferOrTBuffer=*/nullptr)) {
+    Flags.resource = true;
+    switch (ResAttr->getResourceClass()) {
+    case llvm::hlsl::ResourceClass::SRV:
+      Flags.srv = true;
+      break;
+    case llvm::hlsl::ResourceClass::UAV:
+      Flags.uav = true;
+      break;
+    case llvm::hlsl::ResourceClass::CBuffer:
+      Flags.cbv = true;
+      break;
+    case llvm::hlsl::ResourceClass::Sampler:
+      Flags.sampler = true;
+      break;
+    }
+    return Flags;
+  }
+
+  // No resource attribute: sort the variable by the shape of its type.
+  QualType VarTy = Var->getType();
+  if (VarTy->isBuiltinType()) {
+    Flags.basic = true;
+    return Flags;
+  }
+  if (!VarTy->isAggregateType()) {
+    Flags.other = true;
+    return Flags;
+  }
+
+  // User-defined type: walk its members and set the resource class flags
+  // for whatever they contain.
+  Flags.udt = true;
+  if (const RecordType *RT = VarTy->getAs<RecordType>())
+    setResourceClassFlagsFromRecordDecl(Flags, RT->getDecl());
+  return Flags;
+}
+
+/// Diagnoses register annotations on \p D that name the same register type
+/// with different register numbers.
+///
+/// \param S    Sema instance used to emit diagnostics.
+/// \param D    Declaration whose existing HLSLResourceBindingAttrs are checked.
+/// \param Slot Slot string of the annotation currently being processed, e.g.
+///             "t10": one register type character followed by the number.
+///
+/// One err_hlsl_conflicting_register_annotations diagnostic is emitted at the
+/// location of \p D for every register type that has more than one distinct
+/// number.
+static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *D,
+                                                StringRef &Slot) {
+  // Register type character -> distinct register numbers seen for that type.
+  // The whole text after the type character is kept, so that multi-digit
+  // numbers such as "t10" and "t11" are told apart. The numbers are compared
+  // as written, not as parsed integers.
+  std::unordered_map<char, std::set<std::string>> NumbersByRegType;
+
+  auto RecordSlot = [&NumbersByRegType](StringRef Binding) {
+    // An empty slot has no register type character to index by.
+    if (Binding.empty())
+      return;
+    // StringRef carries its own length; do not rely on NUL termination.
+    NumbersByRegType[Binding.front()].insert(Binding.drop_front().str());
+  };
+
+  RecordSlot(Slot);
+  for (const auto *A : D->attrs())
+    if (const auto *Binding = dyn_cast<HLSLResourceBindingAttr>(A))
+      RecordSlot(Binding->getSlot());
+
+  for (const auto &Entry : NumbersByRegType) {
+    if (Entry.second.size() > 1) {
+      std::string regTypeStr(1, Entry.first);
+      S.Diag(D->getLocation(), diag::err_hlsl_conflicting_register_annotations)
+          << regTypeStr;
+    }
+  }
+}
+
+static void DiagnoseHLSLResourceRegType(Sema &S, SourceLocation &ArgLoc,
+                                        Decl *D, StringRef &Slot) {
+
+  // Samplers, UAVs, and SRVs are VarDecl types
+  VarDecl *SamplerUAVOrSRV = dyn_cast<VarDecl>(D);
+  // Cbuffers and Tbuffers are HLSLBufferDecl types
+  HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
+
+  // exactly one of these two types should be set
+  if (!SamplerUAVOrSRV && !CBufferOrTBuffer)
+    return;
+  if (SamplerUAVOrSRV && CBufferOrTBuffer)
+    return;
+
+  register_binding_flags f = HLSLFillRegisterBindingFlags(S, D);
+  assert((int)f.other + (int)f.resource + (int)f.basic + (int)f.udt == 1 &&
+         "only one resource analysis result should be expected");
+
+  // get the variable type
+  std::string typestr;
+  if (SamplerUAVOrSRV) {
+    QualType QT = SamplerUAVOrSRV->getType();
+    PrintingPolicy PP = S.getPrintingPolicy();
+    typestr = QualType::getAsString(QT.split(), PP);
+  } else
+    typestr = CBufferOrTBuffer->isCBuffer() ? "cbuffer" : "tbuffer";
----------------
bob80905 wrote:

Made a function to fetch the relevant data only at diagnostic emission time.

https://github.com/llvm/llvm-project/pull/97103


More information about the cfe-commits mailing list