[llvm] [DirectX] Infrastructure to collect shader flags for each function (PR #112967)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 4 09:37:23 PST 2024


================
@@ -63,17 +118,78 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
   OS << ";\n";
 }
 
+void DXILModuleShaderFlagsInfo::clear() {
+  ModuleFlags = ComputedShaderFlags{};
+  FunctionFlags.clear();
+}
+
+/// Insert the pair <Func, FlagMask> into the sorted vector
+/// FunctionFlags. The insertion is expected to be in-order and hence
+/// is done at the end of the already sorted list.
+[[nodiscard]] bool DXILModuleShaderFlagsInfo::insertInorderFunctionFlags(
+    const Function *Func, ComputedShaderFlags FlagMask) {
+  std::pair<Function const *, ComputedShaderFlags> V{Func, {}};
+  auto Iter = llvm::lower_bound(FunctionFlags, V);
+  if (Iter != FunctionFlags.end())
+    return false;
+
+  FunctionFlags.push_back({Func, FlagMask});
+  return true;
+}
+
+SmallVector<std::pair<Function const *, ComputedShaderFlags>>
+DXILModuleShaderFlagsInfo::getFunctionFlags() const {
+  return FunctionFlags;
+}
+
+ComputedShaderFlags DXILModuleShaderFlagsInfo::getModuleFlags() const {
+  return ModuleFlags;
+}
+
+Expected<const ComputedShaderFlags &>
+DXILModuleShaderFlagsInfo::getShaderFlagsMask(const Function *Func) const {
+  std::pair<Function const *, ComputedShaderFlags> V{Func, {}};
+  // It is correct to delegate comparison of two pairs, say P1, P2, to default
+  // operator< for pairs that returns the evaluation of (P1.first < P2.first)
+  // viz., comparison of Function pointers - the same comparison criterion used
+  // for sorting module functions walked to form FunctionFlags vector.
+  auto Iter = llvm::lower_bound(FunctionFlags, V);
+  if (Iter == FunctionFlags.end()) {
+    return createStringError("Shader Flags information of Function '" +
+                             Func->getName() + "' not found");
+  }
----------------
bharadwajy wrote:

> This isn't the correct way to use `lower_bound`. I think you probably meant to use `binary_search` here with a custom `operator<` that compares `Function*`s with `pair<Function*, ComputedShaderFlags>`.
> 
> If you do want to use `lower_bound`, then you need to account for it returning either `end()`, or an iterator pointing to the first element that is not before the one you're searching for. So you'd normally see a test like this:
> 
> ```
> auto it = lower_bound(FunctionFlags, V);
> if (it == FunctionFlags.end() || it->first != Func) {
>    return error;
> }
> ```
> 
> As it is now, if you had this in your vector: `A B D` and you searched for `C` then you'd get an iterator pointing at `D` and so you'd return the flags for `D`.

Added check for the searched function.
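Below is a minimal standalone sketch of the lookup pattern the review describes. It also shows the append-only insertion used by `insertInorderFunctionFlags`. It assumes standard C++ rather than LLVM: `std::lower_bound` replaces `llvm::lower_bound`, a `std::string` name stands in for the `const Function *` key, and an `unsigned` mask stands in for `ComputedShaderFlags`. `SortedFlagTable`, `insertInOrder`, `lookup` and `demoReviewExample` are hypothetical names, not the PR's API. `std::binary_search` only reports presence as a `bool`, so retrieving the stored flags still needs `lower_bound` plus the equality check.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins: the PR keys entries on `const Function *` and
// stores `ComputedShaderFlags`; a name and a plain bitmask are used here.
using Key = std::string;
using Flags = unsigned;
using Entry = std::pair<Key, Flags>;

class SortedFlagTable {
public:
  // Append-only insert that keeps Entries sorted by key. It succeeds only
  // when lower_bound returns end(), i.e. every existing key orders before
  // Func; duplicate and out-of-order keys are rejected.
  bool insertInOrder(const Key &Func, Flags Mask) {
    if (findPos(Func) != Entries.cend())
      return false;
    Entries.push_back({Func, Mask});
    return true;
  }

  // lower_bound yields the first entry whose key is not less than Func, so
  // the key must also be compared for equality before using the entry.
  std::optional<Flags> lookup(const Key &Func) const {
    auto It = findPos(Func);
    if (It == Entries.cend() || It->first != Func)
      return std::nullopt; // absent, even if It points at a later key
    return It->second;
  }

private:
  // Heterogeneous comparator (entry vs. bare key), so no dummy pair has to
  // be constructed to drive the search.
  std::vector<Entry>::const_iterator findPos(const Key &Func) const {
    return std::lower_bound(
        Entries.cbegin(), Entries.cend(), Func,
        [](const Entry &E, const Key &K) { return E.first < K; });
  }

  std::vector<Entry> Entries;
};

// The scenario from the review: entries A, B, D; searching for C must fail
// instead of returning the flags stored for D.
inline void demoReviewExample() {
  SortedFlagTable Table;
  bool Ok = Table.insertInOrder("A", 0x1) && Table.insertInOrder("B", 0x2) &&
            Table.insertInOrder("D", 0x4);
  assert(Ok);
  (void)Ok; // keep the variable used when NDEBUG disables assert
  assert(!Table.lookup("C").has_value());
  assert(Table.lookup("D").value() == 0x4u);
}
```

Because insertion is append-only and rejects any key that does not sort after the existing ones, the vector stays sorted without a separate sort step. The `It->first != Func` test in `lookup` prevents the mismatch in the review's `A B D` / `C` example.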

https://github.com/llvm/llvm-project/pull/112967

