[llvm] [HLSL] Add support to lookup a ResourceBindingInfo from its use (PR #126556)

via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 14 14:25:45 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: Ashley Coleman (V-FEXrt)

<details>
<summary>Changes</summary>

Adds `findByUse` which takes a `llvm::Value` from a use and resolves it (as best as possible) back to the creation of that resource.

It may return multiple ResourceBindingInfo if the use comes from branched control flow.

Fixes #<!-- -->125746 

---
Full diff: https://github.com/llvm/llvm-project/pull/126556.diff


4 Files Affected:

- (modified) llvm/include/llvm/Analysis/DXILResource.h (+4) 
- (modified) llvm/lib/Analysis/DXILResource.cpp (+44) 
- (modified) llvm/unittests/Target/DirectX/CMakeLists.txt (+2) 
- (added) llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp (+309) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 87c5615c28ee0..9e1e3a6dfc50b 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -446,6 +446,10 @@ class DXILBindingMap {
     return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second);
   }
 
+  // Resolves the use of a resource handle into the unique description of that
+  // resource by deduping calls to create.
+  SmallVector<dxil::ResourceBindingInfo> findByUse(const Value *Key) const;
+
   const_iterator find(const CallInst *Key) const {
     auto Pos = CallMap.find(Key);
     return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second);
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index 7f28e63cc117d..25ff7db7a4d71 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -770,6 +770,50 @@ void DXILBindingMap::print(raw_ostream &OS, DXILResourceTypeMap &DRTM,
   }
 }
 
+SmallVector<dxil::ResourceBindingInfo>
+DXILBindingMap::findByUse(const Value *Key) const {
+  // Walks backwards from Key through PHIs and handle-typed call arguments and
+  // collects, in operand order, the binding of every handlefrombinding call
+  // reached. ActivePhis holds the PHIs currently being expanded so that a
+  // PHI reachable from its own operands (a loop) cannot recurse forever.
+  SmallVector<dxil::ResourceBindingInfo> Found;
+  SmallVector<const PHINode *> ActivePhis;
+  auto Walk = [&](const Value *V, auto &Self) -> void {
+    if (const auto *Phi = dyn_cast<PHINode>(V)) {
+      for (const PHINode *Active : ActivePhis)
+        if (Active == Phi)
+          return;
+      ActivePhis.push_back(Phi);
+      for (const Value *Incoming : Phi->operands())
+        Self(Incoming, Self);
+      ActivePhis.pop_back();
+      return;
+    }
+
+    const auto *CI = dyn_cast<CallInst>(V);
+    if (!CI)
+      return;
+
+    switch (CI->getIntrinsicID()) {
+    // Plain call: keep searching through every argument of the handle's type
+    case Intrinsic::not_intrinsic:
+      for (const Value *Arg : CI->args())
+        if (Arg->getType() == CI->getType())
+          Self(Arg, Self);
+      return;
+    // Found the create, record its binding if the map has one for it
+    case Intrinsic::dx_resource_handlefrombinding:
+      if (const auto *It = find(CI); It != Infos.end())
+        Found.push_back(*It);
+      return;
+    default: // Any other intrinsic: the handle is not traced any further.
+      return;
+    }
+  };
+  Walk(Key, Walk);
+  return Found;
+}
+
 //===----------------------------------------------------------------------===//
 
 AnalysisKey DXILResourceTypeAnalysis::Key;
diff --git a/llvm/unittests/Target/DirectX/CMakeLists.txt b/llvm/unittests/Target/DirectX/CMakeLists.txt
index 626c0d6384268..fd0d5a0dd52c1 100644
--- a/llvm/unittests/Target/DirectX/CMakeLists.txt
+++ b/llvm/unittests/Target/DirectX/CMakeLists.txt
@@ -8,10 +8,12 @@ set(LLVM_LINK_COMPONENTS
   Core
   DirectXCodeGen
   DirectXPointerTypeAnalysis
+  Passes
   Support
   )
 
 add_llvm_target_unittest(DirectXTests
   CBufferDataLayoutTests.cpp
   PointerTypeAnalysisTests.cpp
+  UniqueResourceFromUseTests.cpp
   )
diff --git a/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp b/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp
new file mode 100644
index 0000000000000..5ad7330f05a45
--- /dev/null
+++ b/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp
@@ -0,0 +1,309 @@
+//===- llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "DirectXIRPasses/PointerTypeAnalysis.h"
+#include "DirectXTargetMachine.h"
+#include "llvm/Analysis/DXILResource.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/CodeGen/CommandFlags.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/TypedPointerType.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CodeGen.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Utils/Debugify.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <optional>
+
+using ::testing::Contains;
+using ::testing::Pair;
+
+using namespace llvm;
+using namespace llvm::dxil;
+
+template <typename Ty> struct IsA { // V == IsA<Ty>() holds iff isa<Ty>(V)
+  friend bool operator==(const Value *Val, const IsA &) { return isa<Ty>(Val); }
+};
+
+namespace {
+class UniqueResourceFromUseTest : public testing::Test {
+protected:
+  // Both objects are created in SetUp() and destroyed in TearDown().
+  PassBuilder *PB = nullptr;
+  ModuleAnalysisManager *MAM = nullptr;
+
+  void SetUp() override {
+    MAM = new ModuleAnalysisManager();
+    PB = new PassBuilder();
+    PB->registerModuleAnalyses(*MAM);
+    MAM->registerPass([] { return DXILResourceTypeAnalysis(); });
+    MAM->registerPass([] { return DXILResourceBindingAnalysis(); });
+  }
+  void TearDown() override {
+    delete PB;
+    delete MAM;
+  }
+};
+
+TEST_F(UniqueResourceFromUseTest, TestTrivialUse) {
+  // A handle passed straight to two calls of @a.func: the argument of each
+  // call must resolve to the one binding created in @main.
+  StringRef Assembly = R"(
+define void @main() {
+entry:
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 2, i32 3, i32 4, i1 false)
+  call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+
+  // Look the callee up by name so that a missing @a.func fails the test
+  // instead of silently skipping every check below.
+  const Function *F = M->getFunction("a.func");
+  ASSERT_NE(F, nullptr) << "@a.func must exist in the test module";
+
+  unsigned CalledResources = 0;
+  for (const User *U : F->users()) {
+    const CallInst *CI = dyn_cast<CallInst>(U);
+    ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+    const Value *Handle = CI->getArgOperand(0);
+    const auto Bindings = DBM.findByUse(Handle);
+    ASSERT_EQ(Bindings.size(), 1u)
+        << "Handle should resolve into one resource";
+
+    auto Binding = Bindings[0].getBinding();
+    EXPECT_EQ(0u, Binding.RecordID);
+    EXPECT_EQ(1u, Binding.Space);
+    EXPECT_EQ(2u, Binding.LowerBound);
+    EXPECT_EQ(3u, Binding.Size);
+
+    CalledResources++;
+  }
+
+  EXPECT_EQ(2u, CalledResources)
+      << "Expected 2 resolved call to create resource";
+}
+
+TEST_F(UniqueResourceFromUseTest, TestIndirectUse) {
+  // The handle is threaded through three @ind.func calls before it reaches
+  // @a.func; it must still resolve to the single create call.
+  StringRef Assembly = R"(
+define void @foo() {
+  %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 2, i32 3, i32 4, i1 false)
+  %handle2 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  %handle3 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle2)
+  %handle4 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle3)
+  call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle4)
+  ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+
+  // Look the callee up by name so that a missing @a.func fails the test
+  // instead of silently skipping every check below.
+  const Function *F = M->getFunction("a.func");
+  ASSERT_NE(F, nullptr) << "@a.func must exist in the test module";
+
+  unsigned CalledResources = 0;
+  for (const User *U : F->users()) {
+    const CallInst *CI = dyn_cast<CallInst>(U);
+    ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+    const Value *Handle = CI->getArgOperand(0);
+    const auto Bindings = DBM.findByUse(Handle);
+    ASSERT_EQ(Bindings.size(), 1u)
+        << "Handle should resolve into one resource";
+
+    auto Binding = Bindings[0].getBinding();
+    EXPECT_EQ(0u, Binding.RecordID);
+    EXPECT_EQ(1u, Binding.Space);
+    EXPECT_EQ(2u, Binding.LowerBound);
+    EXPECT_EQ(3u, Binding.Size);
+
+    CalledResources++;
+  }
+
+  EXPECT_EQ(1u, CalledResources)
+      << "Expected 1 resolved call to create resource";
+}
+
+TEST_F(UniqueResourceFromUseTest, TestAmbigousIndirectUse) {
+  // Every @ind.func call takes two handles, so the handle reaching @a.func
+  // may trace back to any of the four create calls; all four are expected.
+  StringRef Assembly = R"(
+define void @foo() {
+  %foo = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 1, i32 1, i32 1, i1 false)
+  %bar = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 2, i32 2, i32 2, i32 2, i1 false)
+  %baz = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 3, i32 3, i32 3, i32 3, i1 false)
+  %bat = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 4, i32 4, i32 4, i32 4, i1 false)
+  %a = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %foo, target("dx.RawBuffer", float, 1, 0) %bar)
+  %b = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %baz, target("dx.RawBuffer", float, 1, 0) %bat)
+  %handle = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %a, target("dx.RawBuffer", float, 1, 0) %b)
+  call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x, target("dx.RawBuffer", float, 1, 0) %y)
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+
+  // Look the callee up by name so that a missing @a.func fails the test
+  // instead of silently skipping every check below.
+  const Function *F = M->getFunction("a.func");
+  ASSERT_NE(F, nullptr) << "@a.func must exist in the test module";
+
+  unsigned CalledResources = 0;
+  for (const User *U : F->users()) {
+    const CallInst *CI = dyn_cast<CallInst>(U);
+    ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+    const Value *Handle = CI->getArgOperand(0);
+    const auto Bindings = DBM.findByUse(Handle);
+    ASSERT_EQ(Bindings.size(), 4u)
+        << "Handle should resolve into four resources";
+
+    auto Binding = Bindings[0].getBinding();
+    EXPECT_EQ(0u, Binding.RecordID);
+    EXPECT_EQ(1u, Binding.Space);
+    EXPECT_EQ(1u, Binding.LowerBound);
+    EXPECT_EQ(1u, Binding.Size);
+
+    Binding = Bindings[1].getBinding();
+    EXPECT_EQ(1u, Binding.RecordID);
+    EXPECT_EQ(2u, Binding.Space);
+    EXPECT_EQ(2u, Binding.LowerBound);
+    EXPECT_EQ(2u, Binding.Size);
+
+    Binding = Bindings[2].getBinding();
+    EXPECT_EQ(2u, Binding.RecordID);
+    EXPECT_EQ(3u, Binding.Space);
+    EXPECT_EQ(3u, Binding.LowerBound);
+    EXPECT_EQ(3u, Binding.Size);
+
+    Binding = Bindings[3].getBinding();
+    EXPECT_EQ(3u, Binding.RecordID);
+    EXPECT_EQ(4u, Binding.Space);
+    EXPECT_EQ(4u, Binding.LowerBound);
+    EXPECT_EQ(4u, Binding.Size);
+
+    CalledResources++;
+  }
+
+  EXPECT_EQ(1u, CalledResources)
+      << "Expected 1 resolved call to create resource";
+}
+
+TEST_F(UniqueResourceFromUseTest, TestConditionalUse) {
+  // The handle reaching @a.func is a PHI of two branches, each derived from
+  // a different create call; both bindings are expected.
+  StringRef Assembly = R"(
+define void @foo(i32 %n) {
+entry:
+  %x = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 1, i32 1, i32 1, i1 false)
+  %y = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 4, i32 4, i32 4, i32 4, i1 false)
+  %cond = icmp eq i32 %n, 0
+  br i1 %cond, label %bb.true, label %bb.false
+
+bb.true:
+  %handle_t = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x)
+  br label %bb.exit
+
+bb.false:
+  %handle_f = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %y)
+  br label %bb.exit
+
+bb.exit:
+  %handle = phi target("dx.RawBuffer", float, 1, 0) [ %handle_t, %bb.true ], [ %handle_f, %bb.false ]
+  call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+  ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x)
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+
+  // Look the callee up by name so that a missing @a.func fails the test
+  // instead of silently skipping every check below.
+  const Function *F = M->getFunction("a.func");
+  ASSERT_NE(F, nullptr) << "@a.func must exist in the test module";
+
+  unsigned CalledResources = 0;
+  for (const User *U : F->users()) {
+    const CallInst *CI = dyn_cast<CallInst>(U);
+    ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+    const Value *Handle = CI->getArgOperand(0);
+    const auto Bindings = DBM.findByUse(Handle);
+    ASSERT_EQ(Bindings.size(), 2u)
+        << "Handle should resolve into two resources";
+
+    auto Binding = Bindings[0].getBinding();
+    EXPECT_EQ(0u, Binding.RecordID);
+    EXPECT_EQ(1u, Binding.Space);
+    EXPECT_EQ(1u, Binding.LowerBound);
+    EXPECT_EQ(1u, Binding.Size);
+
+    Binding = Bindings[1].getBinding();
+    EXPECT_EQ(1u, Binding.RecordID);
+    EXPECT_EQ(4u, Binding.Space);
+    EXPECT_EQ(4u, Binding.LowerBound);
+    EXPECT_EQ(4u, Binding.Size);
+
+    CalledResources++;
+  }
+
+  EXPECT_EQ(1u, CalledResources)
+      << "Expected 1 resolved call to create resource";
+}
+
+} // namespace

``````````

</details>


https://github.com/llvm/llvm-project/pull/126556


More information about the llvm-commits mailing list