[llvm] [HLSL] Add support to lookup a ResourceBindingInfo from its use (PR #126556)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Feb 14 14:25:46 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
Author: Ashley Coleman (V-FEXrt)
<details>
<summary>Changes</summary>
Adds `findByUse` which takes a `llvm::Value` from a use and resolves it (as best as possible) back to the creation of that resource.
It may return multiple ResourceBindingInfo if the use comes from branched control flow.
Fixes #<!-- -->125746
---
Full diff: https://github.com/llvm/llvm-project/pull/126556.diff
4 Files Affected:
- (modified) llvm/include/llvm/Analysis/DXILResource.h (+4)
- (modified) llvm/lib/Analysis/DXILResource.cpp (+44)
- (modified) llvm/unittests/Target/DirectX/CMakeLists.txt (+2)
- (added) llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp (+309)
``````````diff
diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 87c5615c28ee0..9e1e3a6dfc50b 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -446,6 +446,10 @@ class DXILBindingMap {
return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second);
}
+ // Resolves the use of a resource handle into the unique description of that
+ // resource by deduping calls to create.
+ SmallVector<dxil::ResourceBindingInfo> findByUse(const Value *Key) const;
+
const_iterator find(const CallInst *Key) const {
auto Pos = CallMap.find(Key);
return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second);
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index 7f28e63cc117d..25ff7db7a4d71 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -770,6 +770,50 @@ void DXILBindingMap::print(raw_ostream &OS, DXILResourceTypeMap &DRTM,
}
}
+// Follows Key back through PHIs and through calls forwarding a handle of their
+// own type, collecting the binding of every dx.resource.handlefrombinding call
+// reached (in operand order; empty if none). Each value is expanded only once,
+// so loop-carried PHIs terminate and a shared create is reported a single time.
+SmallVector<dxil::ResourceBindingInfo>
+DXILBindingMap::findByUse(const Value *Key) const {
+  SmallVector<dxil::ResourceBindingInfo> Bindings;
+  SmallVector<const Value *> Visited;
+
+  auto Walk = [&](const Value *V, auto &&Self) -> void {
+    for (const Value *Seen : Visited)
+      if (Seen == V)
+        return;
+    Visited.push_back(V);
+
+    // A PHI merges handles from several predecessors; follow all of them.
+    if (const auto *Phi = dyn_cast<PHINode>(V)) {
+      for (const Value *Incoming : Phi->operands())
+        Self(Incoming, Self);
+      return;
+    }
+
+    const auto *CI = dyn_cast<CallInst>(V);
+    if (!CI)
+      return;
+    switch (CI->getIntrinsicID()) {
+    // An ordinary call may forward a handle: follow same-typed arguments.
+    case Intrinsic::not_intrinsic:
+      for (const Value *Arg : CI->args())
+        if (Arg->getType() == CI->getType())
+          Self(Arg, Self);
+      return;
+    // Found the create: record its binding if the map knows this call.
+    case Intrinsic::dx_resource_handlefrombinding:
+      if (const auto *It = find(CI); It != Infos.end())
+        Bindings.push_back(*It);
+      return;
+    }
+  };
+  Walk(Key, Walk);
+  return Bindings;
+}
+
+
//===----------------------------------------------------------------------===//
AnalysisKey DXILResourceTypeAnalysis::Key;
diff --git a/llvm/unittests/Target/DirectX/CMakeLists.txt b/llvm/unittests/Target/DirectX/CMakeLists.txt
index 626c0d6384268..fd0d5a0dd52c1 100644
--- a/llvm/unittests/Target/DirectX/CMakeLists.txt
+++ b/llvm/unittests/Target/DirectX/CMakeLists.txt
@@ -8,10 +8,12 @@ set(LLVM_LINK_COMPONENTS
Core
DirectXCodeGen
DirectXPointerTypeAnalysis
+ Passes
Support
)
add_llvm_target_unittest(DirectXTests
CBufferDataLayoutTests.cpp
PointerTypeAnalysisTests.cpp
+ UniqueResourceFromUseTests.cpp
)
diff --git a/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp b/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp
new file mode 100644
index 0000000000000..5ad7330f05a45
--- /dev/null
+++ b/llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp
@@ -0,0 +1,309 @@
+//===- llvm/unittests/Target/DirectX/UniqueResourceFromUseTests.cpp -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "DirectXIRPasses/PointerTypeAnalysis.h"
+#include "DirectXTargetMachine.h"
+#include "llvm/Analysis/DXILResource.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/CodeGen/CommandFlags.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/TypedPointerType.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CodeGen.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Utils/Debugify.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <optional>
+
+using ::testing::Contains;
+using ::testing::Pair;
+
+using namespace llvm;
+using namespace llvm::dxil;
+
+template <typename T> struct IsA { // `V == IsA<T>()` yields isa<T>(V).
+ friend bool operator==(const Value *V, const IsA &) { return isa<T>(V); }
+}; // NOTE(review): no test visible in this file references IsA.
+
+namespace {
+// Per-test fixture: builds a PassBuilder and a ModuleAnalysisManager and
+// registers the DXIL resource type and binding analyses on the manager.
+class UniqueResourceFromUseTest : public testing::Test {
+protected:
+  PassBuilder *PB;
+  ModuleAnalysisManager *MAM;
+  void SetUp() override {
+    MAM = new ModuleAnalysisManager();
+    PB = new PassBuilder();
+    PB->registerModuleAnalyses(*MAM);
+    MAM->registerPass([] { return DXILResourceTypeAnalysis(); });
+    MAM->registerPass([] { return DXILResourceBindingAnalysis(); });
+  }
+  void TearDown() override {
+    delete PB;
+    delete MAM;
+  }
+};
+
+// A handle passed straight from its creation call to a user must resolve to
+// exactly that one binding, at every use site.
+TEST_F(UniqueResourceFromUseTest, TestTrivialUse) {
+  StringRef Assembly = R"(
+define void @main() {
+entry:
+ %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 2, i32 3, i32 4, i1 false)
+ call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+ call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+ ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+ )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+
+  // Look the sink up by name and fail loudly if it is missing; scanning all
+  // functions for it would let the test pass without checking anything.
+  const Function *Sink = M->getFunction("a.func");
+  ASSERT_TRUE(Sink) << "@a.func must be declared in the test module";
+
+  unsigned CalledResources = 0;
+  for (const User *U : Sink->users()) {
+    const auto *CI = dyn_cast<CallInst>(U);
+    ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+    const Value *Handle = CI->getArgOperand(0);
+    const auto Bindings = DBM.findByUse(Handle);
+    ASSERT_EQ(Bindings.size(), 1u)
+        << "Handle should resolve into one resource";
+
+    auto Binding = Bindings[0].getBinding();
+    EXPECT_EQ(0u, Binding.RecordID);
+    EXPECT_EQ(1u, Binding.Space);
+    EXPECT_EQ(2u, Binding.LowerBound);
+    EXPECT_EQ(3u, Binding.Size);
+
+    CalledResources++;
+  }
+
+  EXPECT_EQ(2u, CalledResources)
+      << "Expected 2 resolved call to create resource";
+}
+
+// A handle forwarded through a chain of @ind.func calls must still resolve to
+// the single creation call at the head of the chain.
+TEST_F(UniqueResourceFromUseTest, TestIndirectUse) {
+  StringRef Assembly = R"(
+define void @foo() {
+ %handle = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 2, i32 3, i32 4, i1 false)
+ %handle2 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle)
+ %handle3 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle2)
+ %handle4 = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle3)
+ call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle4)
+ ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %handle)
+ )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+
+  // Look the sink up by name and fail loudly if it is missing; scanning all
+  // functions for it would let the test pass without checking anything.
+  const Function *Sink = M->getFunction("a.func");
+  ASSERT_TRUE(Sink) << "@a.func must be declared in the test module";
+
+  unsigned CalledResources = 0;
+  for (const User *U : Sink->users()) {
+    const auto *CI = dyn_cast<CallInst>(U);
+    ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+    const Value *Handle = CI->getArgOperand(0);
+    const auto Bindings = DBM.findByUse(Handle);
+    ASSERT_EQ(Bindings.size(), 1u)
+        << "Handle should resolve into one resource";
+
+    auto Binding = Bindings[0].getBinding();
+    EXPECT_EQ(0u, Binding.RecordID);
+    EXPECT_EQ(1u, Binding.Space);
+    EXPECT_EQ(2u, Binding.LowerBound);
+    EXPECT_EQ(3u, Binding.Size);
+
+    CalledResources++;
+  }
+
+  EXPECT_EQ(1u, CalledResources)
+      << "Expected 1 resolved call to create resource";
+}
+
+// When each forwarding call takes two handles, every creation call feeding
+// the tree must be reported, in operand order (%foo, %bar, %baz, %bat).
+TEST_F(UniqueResourceFromUseTest, TestAmbigousIndirectUse) {
+  StringRef Assembly = R"(
+define void @foo() {
+ %foo = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 1, i32 1, i32 1, i1 false)
+ %bar = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 2, i32 2, i32 2, i32 2, i1 false)
+ %baz = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 3, i32 3, i32 3, i32 3, i1 false)
+ %bat = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 4, i32 4, i32 4, i32 4, i1 false)
+ %a = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %foo, target("dx.RawBuffer", float, 1, 0) %bar)
+ %b = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %baz, target("dx.RawBuffer", float, 1, 0) %bat)
+ %handle = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %a, target("dx.RawBuffer", float, 1, 0) %b)
+ call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+ ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x, target("dx.RawBuffer", float, 1, 0) %y)
+ )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+
+  // Look the sink up by name and fail loudly if it is missing; scanning all
+  // functions for it would let the test pass without checking anything.
+  const Function *Sink = M->getFunction("a.func");
+  ASSERT_TRUE(Sink) << "@a.func must be declared in the test module";
+
+  unsigned CalledResources = 0;
+  for (const User *U : Sink->users()) {
+    const auto *CI = dyn_cast<CallInst>(U);
+    ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+    const Value *Handle = CI->getArgOperand(0);
+    const auto Bindings = DBM.findByUse(Handle);
+    ASSERT_EQ(Bindings.size(), 4u)
+        << "Handle should resolve into four resources";
+
+    auto Binding = Bindings[0].getBinding();
+    EXPECT_EQ(0u, Binding.RecordID);
+    EXPECT_EQ(1u, Binding.Space);
+    EXPECT_EQ(1u, Binding.LowerBound);
+    EXPECT_EQ(1u, Binding.Size);
+
+    Binding = Bindings[1].getBinding();
+    EXPECT_EQ(1u, Binding.RecordID);
+    EXPECT_EQ(2u, Binding.Space);
+    EXPECT_EQ(2u, Binding.LowerBound);
+    EXPECT_EQ(2u, Binding.Size);
+
+    Binding = Bindings[2].getBinding();
+    EXPECT_EQ(2u, Binding.RecordID);
+    EXPECT_EQ(3u, Binding.Space);
+    EXPECT_EQ(3u, Binding.LowerBound);
+    EXPECT_EQ(3u, Binding.Size);
+
+    Binding = Bindings[3].getBinding();
+    EXPECT_EQ(3u, Binding.RecordID);
+    EXPECT_EQ(4u, Binding.Space);
+    EXPECT_EQ(4u, Binding.LowerBound);
+    EXPECT_EQ(4u, Binding.Size);
+
+    CalledResources++;
+  }
+
+  EXPECT_EQ(1u, CalledResources)
+      << "Expected 1 resolved call to create resource";
+}
+
+// A PHI merging handles from two branches must resolve to the bindings of
+// both incoming creation calls, in PHI operand order (%x, then %y).
+TEST_F(UniqueResourceFromUseTest, TestConditionalUse) {
+  StringRef Assembly = R"(
+define void @foo(i32 %n) {
+entry:
+ %x = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 1, i32 1, i32 1, i32 1, i1 false)
+ %y = call target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32 4, i32 4, i32 4, i32 4, i1 false)
+ %cond = icmp eq i32 %n, 0
+ br i1 %cond, label %bb.true, label %bb.false
+
+bb.true:
+ %handle_t = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x)
+ br label %bb.exit
+
+bb.false:
+ %handle_f = call target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %y)
+ br label %bb.exit
+
+bb.exit:
+ %handle = phi target("dx.RawBuffer", float, 1, 0) [ %handle_t, %bb.true ], [ %handle_f, %bb.false ]
+ call void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+ ret void
+}
+
+declare target("dx.RawBuffer", float, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_f32_1_0t(i32, i32, i32, i32, i1)
+declare void @a.func(target("dx.RawBuffer", float, 1, 0) %handle)
+declare target("dx.RawBuffer", float, 1, 0) @ind.func(target("dx.RawBuffer", float, 1, 0) %x)
+ )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  const DXILBindingMap &DBM = MAM->getResult<DXILResourceBindingAnalysis>(*M);
+
+  // Look the sink up by name and fail loudly if it is missing; scanning all
+  // functions for it would let the test pass without checking anything.
+  const Function *Sink = M->getFunction("a.func");
+  ASSERT_TRUE(Sink) << "@a.func must be declared in the test module";
+
+  unsigned CalledResources = 0;
+  for (const User *U : Sink->users()) {
+    const auto *CI = dyn_cast<CallInst>(U);
+    ASSERT_TRUE(CI) << "All users of @a.func must be CallInst";
+
+    const Value *Handle = CI->getArgOperand(0);
+    const auto Bindings = DBM.findByUse(Handle);
+    ASSERT_EQ(Bindings.size(), 2u)
+        << "Handle should resolve into two resources";
+
+    auto Binding = Bindings[0].getBinding();
+    EXPECT_EQ(0u, Binding.RecordID);
+    EXPECT_EQ(1u, Binding.Space);
+    EXPECT_EQ(1u, Binding.LowerBound);
+    EXPECT_EQ(1u, Binding.Size);
+
+    Binding = Bindings[1].getBinding();
+    EXPECT_EQ(1u, Binding.RecordID);
+    EXPECT_EQ(4u, Binding.Space);
+    EXPECT_EQ(4u, Binding.LowerBound);
+    EXPECT_EQ(4u, Binding.Size);
+
+    CalledResources++;
+  }
+
+  EXPECT_EQ(1u, CalledResources)
+      << "Expected 1 resolved call to create resource";
+}
+
+} // namespace
``````````
</details>
https://github.com/llvm/llvm-project/pull/126556
More information about the llvm-commits
mailing list