[llvm] [HLSL] Analyze update counter usage (PR #130356)

Ashley Coleman via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 26 15:47:47 PDT 2025


================
@@ -823,8 +827,147 @@ DXILBindingMap::findByUse(const Value *Key) const {
 
 //===----------------------------------------------------------------------===//
 
+/// Returns true if \p F is the dx_resource_updatecounter intrinsic.
+static bool isUpdateCounterIntrinsic(Function &F) {
+  const Intrinsic::ID ID = F.getIntrinsicID();
+  return ID == Intrinsic::dx_resource_updatecounter;
+}
+
+/// Builds CounterDirections for \p M.
+///
+/// Every call to the dx_resource_updatecounter intrinsic is inspected. A
+/// positive count argument marks each binding the handle may originate from
+/// (as reported by \p DBM) as incremented, a negative one as decremented, and
+/// a zero count is ignored. A binding that ends up both incremented and
+/// decremented is marked ResourceCounterDirection::Invalid and an error is
+/// diagnosed at the offending call sites.
+void DXILResourceCounterDirectionMap::populate(Module &M, DXILBindingMap &DBM) {
+  // (binding, direction, function containing the call, call). The last two
+  // fields are only used to give diagnostics a location.
+  SmallVector<std::tuple<dxil::ResourceBindingInfo, ResourceCounterDirection,
+                         const Function *, const CallInst *>>
+      DiagCounterDirs;
+
+  for (Function &F : M.functions()) {
+    if (!isUpdateCounterIntrinsic(F))
+      continue;
+
+    for (const User *U : F.users()) {
+      const CallInst *CI = dyn_cast<CallInst>(U);
+      assert(CI && "Users of dx_resource_updateCounter must be call instrs");
+
+      // Determine if the use is an increment or decrement.
+      // NOTE(review): cast<> assumes the count operand is always a constant.
+      Value *CountArg = CI->getArgOperand(1);
+      ConstantInt *CountValue = cast<ConstantInt>(CountArg);
+      int64_t CountLiteral = CountValue->getSExtValue();
+
+      // 0 is an unknown direction and shouldn't result in an insert
+      if (CountLiteral == 0)
+        continue;
+
+      ResourceCounterDirection Direction = ResourceCounterDirection::Decrement;
+      if (CountLiteral > 0)
+        Direction = ResourceCounterDirection::Increment;
+
+      // Collect all potential creation points for the handle arg. Record the
+      // function that *contains* the call, not the intrinsic declaration F,
+      // so diagnostics are attributed to the right function.
+      Value *HandleArg = CI->getArgOperand(0);
+      SmallVector<dxil::ResourceBindingInfo> RBInfos = DBM.findByUse(HandleArg);
+      for (const dxil::ResourceBindingInfo &RBInfo : RBInfos)
+        DiagCounterDirs.emplace_back(RBInfo, Direction, CI->getFunction(), CI);
+    }
+  }
+
+  // Sort/dedupe key: (binding, direction). std::tie compares by reference,
+  // avoiding a ResourceBindingInfo copy on every comparison.
+  auto Key = [](const auto &Entry) {
+    return std::tie(std::get<dxil::ResourceBindingInfo>(Entry),
+                    std::get<ResourceCounterDirection>(Entry));
+  };
+
+  // Sort by the Binding and Direction for fast lookup. The sort is stable so
+  // entries with equal keys keep module order. That makes the call site kept
+  // by std::unique below (and hence the diagnostic location) deterministic.
+  std::stable_sort(DiagCounterDirs.begin(), DiagCounterDirs.end(),
+                   [&Key](const auto &LHS, const auto &RHS) {
+                     return Key(LHS) < Key(RHS);
+                   });
+
+  // Remove the duplicate entries. Since direction is considered for equality
+  // a unique resource with more than one direction will not be deduped.
+  auto *const UniqueEnd =
+      std::unique(DiagCounterDirs.begin(), DiagCounterDirs.end(),
+                  [&Key](const auto &LHS, const auto &RHS) {
+                    return Key(LHS) == Key(RHS);
+                  });
+
+  // Actually erase the invalidated items
+  DiagCounterDirs.erase(UniqueEnd, DiagCounterDirs.end());
+
+  // If a binding still appears more than once at this point, it was both
+  // incremented and decremented, which is not allowed. Mark all those entries
+  // as invalid. Indexing from 1 avoids forming begin() + 1 on an empty vector.
+  for (size_t I = 1, E = DiagCounterDirs.size(); I < E; ++I) {
+    auto &Prev = DiagCounterDirs[I - 1];
+    auto &Cur = DiagCounterDirs[I];
+    if (std::get<dxil::ResourceBindingInfo>(Prev) ==
+        std::get<dxil::ResourceBindingInfo>(Cur)) {
+      std::get<ResourceCounterDirection>(Prev) =
+          ResourceCounterDirection::Invalid;
+      std::get<ResourceCounterDirection>(Cur) =
+          ResourceCounterDirection::Invalid;
+    }
+  }
+
+  // Raise an error for every invalid entry
+  for (const auto &Entry : DiagCounterDirs) {
+    if (std::get<ResourceCounterDirection>(Entry) !=
+        ResourceCounterDirection::Invalid)
+      continue;
+
+    const Function *F = std::get<const Function *>(Entry);
+    const CallInst *CI = std::get<const CallInst *>(Entry);
+    StringRef Message = "RWStructuredBuffers may increment or decrement their "
+                        "counters, but not both.";
+    M.getContext().diagnose(
+        DiagnosticInfoGenericWithLoc(Message, *F, CI->getDebugLoc()));
+  }
+
+  // Copy the results into the final vec
+  CounterDirections.clear();
+  CounterDirections.reserve(DiagCounterDirs.size());
+  std::transform(DiagCounterDirs.begin(), DiagCounterDirs.end(),
+                 std::back_inserter(CounterDirections), [](const auto &Item) {
+                   return std::pair{std::get<dxil::ResourceBindingInfo>(Item),
+                                    std::get<ResourceCounterDirection>(Item)};
+                 });
+}
+
+/// Declares that this analysis transitively requires the resource binding
+/// analysis and that it preserves all other analyses.
+void DXILResourceCounterDirectionWrapperPass::getAnalysisUsage(
+    AnalysisUsage &AU) const {
+  AU.setPreservesAll();
+  AU.addRequiredTransitive<DXILResourceBindingWrapperPass>();
+}
+
+/// Rebuilds Map from the resource binding analysis for \p M.
+/// Returns false: the module is not modified.
+bool DXILResourceCounterDirectionWrapperPass::runOnModule(Module &M) {
+  Map.reset(new DXILResourceCounterDirectionMap());
+
+  // Bind by reference; a plain `auto` would deep-copy the whole binding map.
+  DXILBindingMap &DBM =
+      getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap();
+  Map->populate(M, DBM);
+
+  return false;
+}
+
+/// Drops the computed counter-direction map; Map is null afterwards.
+void DXILResourceCounterDirectionWrapperPass::releaseMemory() {
+  Map.reset();
+}
+
+void DXILResourceCounterDirectionWrapperPass::print(raw_ostream &OS,
+                                                    const Module *M) const {
+  if (!Map) {
+    OS << "No resource directions have been built!\n";
+    return;
+  }
+  // Map->print(OS, *DRTM, M->getDataLayout());
----------------
V-FEXrt wrote:

We chatted offline, and it looks like we'll just remove the print function instead of implementing another pass, since the print function is normally used for testing and this functionality is tested in other ways.

https://github.com/llvm/llvm-project/pull/130356


More information about the llvm-commits mailing list