[llvm] [DirectX] Propagate shader flags mask of callees to callers (PR #118306)
S. Bharadwaj Yadavalli via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 6 12:52:46 PST 2024
https://github.com/bharadwajy updated https://github.com/llvm/llvm-project/pull/118306
>From 6a4fef8b1c8ad78db63016ddc927bbbebeea04c5 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 27 Nov 2024 11:25:34 -0500
Subject: [PATCH 1/5] Propagate shader flags mask of callees to callers
Add test to verify propagation of shader flags
---
llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 43 ++++++++-
llvm/lib/Target/DirectX/DXILShaderFlags.h | 2 +
.../DirectX/ShaderFlags/double-extensions.ll | 7 ++
.../propagate-function-flags-test.ll | 92 +++++++++++++++++++
4 files changed, 140 insertions(+), 4 deletions(-)
create mode 100644 llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index d6917dce98abd5..f242204363cfe8 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -15,6 +15,7 @@
#include "DirectX.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
@@ -47,10 +48,14 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF,
}
void ModuleShaderFlags::initialize(const Module &M) {
+ SmallVector<const Function *> WorkList;
// Collect shader flags for each of the functions
for (const auto &F : M.getFunctionList()) {
if (F.isDeclaration())
continue;
+ if (!F.user_empty()) {
+ WorkList.push_back(&F);
+ }
ComputedShaderFlags CSF;
for (const auto &BB : F)
for (const auto &I : BB)
@@ -61,6 +66,21 @@ void ModuleShaderFlags::initialize(const Module &M) {
CombinedSFMask.merge(CSF);
}
llvm::sort(FunctionFlags);
+ // Propagate shader flag mask of functions to their callers.
+ while (!WorkList.empty()) {
+ const Function *Func = WorkList.pop_back_val();
+ if (!Func->user_empty()) {
+ ComputedShaderFlags FuncSF = getFunctionFlags(Func);
+ // Update mask of callers with that of Func
+ for (const auto User : Func->users()) {
+ if (const CallInst *CI = dyn_cast<CallInst>(User)) {
+ const Function *Caller = CI->getParent()->getParent();
+ if (mergeFunctionShaderFlags(Caller, FuncSF))
+ WorkList.push_back(Caller);
+ }
+ }
+ }
+ }
}
void ComputedShaderFlags::print(raw_ostream &OS) const {
@@ -81,16 +101,31 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
OS << ";\n";
}
-/// Return the shader flags mask of the specified function Func.
-const ComputedShaderFlags &
-ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
+auto ModuleShaderFlags::getFunctionShaderFlagInfo(const Function *Func) const {
const auto Iter = llvm::lower_bound(
FunctionFlags, Func,
[](const std::pair<const Function *, ComputedShaderFlags> FSM,
const Function *FindFunc) { return (FSM.first < FindFunc); });
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
"No Shader Flags Mask exists for function");
- return Iter->second;
+ return Iter;
+}
+
+/// Merge mask NewSF to that of Func, if different.
+/// Return true if mask of Func is changed, else false.
+bool ModuleShaderFlags::mergeFunctionShaderFlags(
+ const Function *Func, const ComputedShaderFlags NewSF) {
+ const auto FuncSFInfo = getFunctionShaderFlagInfo(Func);
+ if ((FuncSFInfo->second & NewSF) != NewSF) {
+ const_cast<ComputedShaderFlags &>(FuncSFInfo->second).merge(NewSF);
+ return true;
+ }
+ return false;
+}
+/// Return the shader flags mask of the specified function Func.
+const ComputedShaderFlags &
+ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
+ return getFunctionShaderFlagInfo(Func)->second;
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 2d60137f8b191c..8c581f243ca98b 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -95,6 +95,8 @@ struct ModuleShaderFlags {
SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
/// Combined Shader Flag Mask of all functions of the module
ComputedShaderFlags CombinedSFMask{};
+ auto getFunctionShaderFlagInfo(const Function *) const;
+ bool mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
};
class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
index 6332ef806a0d8f..8e5e61b42469ad 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
@@ -12,6 +12,13 @@ target triple = "dxil-pc-shadermodel6.7-library"
; CHECK-NEXT: ;
; CHECK-NEXT: ; Shader Flags for Module Functions
+;CHECK: ; Function top_level : 0x00000044
+define void @top_level() #0 {
+ call void @test_uitofp_i64(i64 noundef 5)
+ ret void
+}
+
+
; CHECK: ; Function test_fdiv_double : 0x00000044
define double @test_fdiv_double(double %a, double %b) #0 {
%res = fdiv double %a, %b
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
new file mode 100644
index 00000000000000..93d634c0384ae7
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
@@ -0,0 +1,92 @@
+; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; CHECK: ; Combined Shader Flags for Module
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ; Double-precision floating point
+; CHECK-NEXT: ; Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags for Module Functions
+
+; CHECK: ; Function call_n6 : 0x00000044
+define double @call_n6(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = sitofp i32 %0 to double
+ ret double %2
+}
+; CHECK: ; Function call_n4 : 0x00000044
+define double @call_n4(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = tail call double @call_n6(i32 noundef %0)
+ ret double %2
+}
+
+; CHECK: ; Function call_n7 : 0x00000044
+define double @call_n7(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = uitofp i32 %0 to double
+ ret double %2
+}
+
+; CHECK: ; Function call_n5 : 0x00000044
+define double @call_n5(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = tail call double @call_n7(i32 noundef %0)
+ ret double %2
+}
+
+; CHECK: ; Function call_n2 : 0x00000044
+define double @call_n2(i64 noundef %0) local_unnamed_addr #0 {
+ %2 = icmp ult i64 %0, 6
+ br i1 %2, label %3, label %7
+
+3: ; preds = %1
+ %4 = add nuw nsw i64 %0, 1
+ %5 = uitofp i64 %4 to double
+ %6 = tail call double @call_n1(double noundef %5)
+ br label %10
+
+7: ; preds = %1
+ %8 = trunc i64 %0 to i32
+ %9 = tail call double @call_n4(i32 noundef %8)
+ br label %10
+
+10: ; preds = %7, %3
+ %11 = phi double [ %6, %3 ], [ %9, %7 ]
+ ret double %11
+}
+
+; CHECK: ; Function call_n1 : 0x00000044
+define double @call_n1(double noundef %0) local_unnamed_addr #0 {
+ %2 = fcmp ugt double %0, 5.000000e+00
+ br i1 %2, label %6, label %3
+
+3: ; preds = %1
+ %4 = fptoui double %0 to i64
+ %5 = tail call double @call_n2(i64 noundef %4)
+ br label %9
+
+6: ; preds = %1
+ %7 = fptoui double %0 to i32
+ %8 = tail call double @call_n5(i32 noundef %7)
+ br label %9
+
+9: ; preds = %6, %3
+ %10 = phi double [ %5, %3 ], [ %8, %6 ]
+ ret double %10
+}
+
+; CHECK: ; Function call_n3 : 0x00000044
+define double @call_n3(double noundef %0) local_unnamed_addr #0 {
+ %2 = fdiv double %0, 3.000000e+00
+ ret double %2
+}
+
+; CHECK: ; Function main : 0x00000044
+define i32 @main() local_unnamed_addr #0 {
+ %1 = tail call double @call_n1(double noundef 1.000000e+00)
+ %2 = tail call double @call_n3(double noundef %1)
+ ret i32 0
+}
+
+attributes #0 = { convergent norecurse nounwind "hlsl.export"}
>From a360efc5d0a90b4bf36384fa779df01c07bba1fd Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Tue, 3 Dec 2024 19:52:07 -0500
Subject: [PATCH 2/5] Address PR feedback
---
llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 48 +++++++++++----------
llvm/lib/Target/DirectX/DXILShaderFlags.h | 10 ++---
2 files changed, 30 insertions(+), 28 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index f242204363cfe8..44fa5a2bb060b6 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -70,7 +70,7 @@ void ModuleShaderFlags::initialize(const Module &M) {
while (!WorkList.empty()) {
const Function *Func = WorkList.pop_back_val();
if (!Func->user_empty()) {
- ComputedShaderFlags FuncSF = getFunctionFlags(Func);
+ const ComputedShaderFlags &FuncSF = getFunctionFlags(Func);
// Update mask of callers with that of Func
for (const auto User : Func->users()) {
if (const CallInst *CI = dyn_cast<CallInst>(User)) {
@@ -101,31 +101,35 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
OS << ";\n";
}
-auto ModuleShaderFlags::getFunctionShaderFlagInfo(const Function *Func) const {
- const auto Iter = llvm::lower_bound(
- FunctionFlags, Func,
- [](const std::pair<const Function *, ComputedShaderFlags> FSM,
- const Function *FindFunc) { return (FSM.first < FindFunc); });
- assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
- "No Shader Flags Mask exists for function");
- return Iter;
+static bool compareShaderFlagsInfo(
+ const std::pair<const Function *, ComputedShaderFlags> FSM,
+ const Function *FindFunc) {
+ return (FSM.first < FindFunc);
}
-/// Merge mask NewSF to that of Func, if different.
-/// Return true if mask of Func is changed, else false.
-bool ModuleShaderFlags::mergeFunctionShaderFlags(
- const Function *Func, const ComputedShaderFlags NewSF) {
- const auto FuncSFInfo = getFunctionShaderFlagInfo(Func);
- if ((FuncSFInfo->second & NewSF) != NewSF) {
- const_cast<ComputedShaderFlags &>(FuncSFInfo->second).merge(NewSF);
- return true;
- }
- return false;
-}
/// Return the shader flags mask of the specified function Func.
const ComputedShaderFlags &
ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
- return getFunctionShaderFlagInfo(Func)->second;
+ const std::pair<const Function *, ComputedShaderFlags> *Iter =
+ llvm::lower_bound(FunctionFlags, Func, compareShaderFlagsInfo);
+ assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
+ "Get Shader Flags : No Shader Flags Mask exists for function");
+ return Iter->second;
+}
+
+/// Merge specified shader flags mask SF with current mask of the specified
+/// function Func.
+/// Return true if merge operation changes the value of shader flags mask of
+/// Func; else false.
+bool ModuleShaderFlags::mergeFunctionShaderFlags(const Function *Func,
+ ComputedShaderFlags SF) {
+ std::pair<const Function *, ComputedShaderFlags> *Iter =
+ llvm::lower_bound(FunctionFlags, Func, compareShaderFlagsInfo);
+ assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
+ "Merge Shader Flags : No Shader Flags Mask exists for function");
+ ComputedShaderFlags PreMergeSF = Iter->second;
+ Iter->second.merge(SF);
+ return (PreMergeSF != Iter->second);
}
//===----------------------------------------------------------------------===//
@@ -152,7 +156,7 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
for (const auto &F : M.getFunctionList()) {
if (F.isDeclaration())
continue;
- auto SFMask = FlagsInfo.getFunctionFlags(&F);
+ const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F);
OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
(uint64_t)(SFMask));
}
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 8c581f243ca98b..158e4ea9033364 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -70,11 +70,10 @@ struct ComputedShaderFlags {
return FeatureFlags;
}
- void merge(const uint64_t IVal) {
+ void merge(const ComputedShaderFlags CSF) {
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
- FlagName |= (IVal & getMask(DxilModuleBit));
-#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
- FlagName |= (IVal & getMask(DxilModuleBit));
+ FlagName |= CSF.FlagName;
+#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName |= CSF.FlagName;
#include "llvm/BinaryFormat/DXContainerConstants.def"
return;
}
@@ -92,10 +91,9 @@ struct ModuleShaderFlags {
/// Vector of sorted Function-Shader Flag mask pairs representing properties
/// of each of the functions in the module. Shader Flags of each function
/// represent both module-level and function-level flags
- SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
+ SmallVector<std::pair<const Function *, ComputedShaderFlags>> FunctionFlags;
/// Combined Shader Flag Mask of all functions of the module
ComputedShaderFlags CombinedSFMask{};
- auto getFunctionShaderFlagInfo(const Function *) const;
bool mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
};
>From b7cb3cd12a5816f55037e55cb981e10cdc462fbf Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 4 Dec 2024 15:07:11 -0500
Subject: [PATCH 3/5] Address PR feedback - Use DenseMap instead of SmallVector
for Function-Shader Flags Mask map - Return true from
ComputedShaderFlags::merge() if value is modified - Style nits
---
llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 48 +++++++++------------
llvm/lib/Target/DirectX/DXILShaderFlags.h | 20 ++++++---
2 files changed, 33 insertions(+), 35 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 44fa5a2bb060b6..3e1009124ceeca 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -53,31 +53,33 @@ void ModuleShaderFlags::initialize(const Module &M) {
for (const auto &F : M.getFunctionList()) {
if (F.isDeclaration())
continue;
- if (!F.user_empty()) {
+
+ if (!F.user_empty())
WorkList.push_back(&F);
- }
+
ComputedShaderFlags CSF;
for (const auto &BB : F)
for (const auto &I : BB)
updateFunctionFlags(CSF, I);
// Insert shader flag mask for function F
- FunctionFlags.push_back({&F, CSF});
+ FunctionFlags.insert({&F, CSF});
// Update combined shader flags mask
CombinedSFMask.merge(CSF);
}
- llvm::sort(FunctionFlags);
+
// Propagate shader flag mask of functions to their callers.
while (!WorkList.empty()) {
const Function *Func = WorkList.pop_back_val();
- if (!Func->user_empty()) {
- const ComputedShaderFlags &FuncSF = getFunctionFlags(Func);
- // Update mask of callers with that of Func
- for (const auto User : Func->users()) {
- if (const CallInst *CI = dyn_cast<CallInst>(User)) {
- const Function *Caller = CI->getParent()->getParent();
- if (mergeFunctionShaderFlags(Caller, FuncSF))
- WorkList.push_back(Caller);
- }
+ if (Func->user_empty())
+ continue;
+
+ const ComputedShaderFlags &FuncSF = getFunctionFlags(Func);
+ // Update mask of callers with that of Func
+ for (const auto User : Func->users()) {
+ if (const CallInst *CI = dyn_cast<CallInst>(User)) {
+ const Function *Caller = CI->getParent()->getParent();
+ if (mergeFunctionShaderFlags(Caller, FuncSF))
+ WorkList.push_back(Caller);
}
}
}
@@ -101,17 +103,10 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
OS << ";\n";
}
-static bool compareShaderFlagsInfo(
- const std::pair<const Function *, ComputedShaderFlags> FSM,
- const Function *FindFunc) {
- return (FSM.first < FindFunc);
-}
-
/// Return the shader flags mask of the specified function Func.
const ComputedShaderFlags &
ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
- const std::pair<const Function *, ComputedShaderFlags> *Iter =
- llvm::lower_bound(FunctionFlags, Func, compareShaderFlagsInfo);
+ auto Iter = FunctionFlags.find(Func);
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
"Get Shader Flags : No Shader Flags Mask exists for function");
return Iter->second;
@@ -119,17 +114,14 @@ ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
/// Merge specified shader flags mask SF with current mask of the specified
/// function Func.
-/// Return true if merge operation changes the value of shader flags mask of
-/// Func; else false.
+/// Return merge result status viz., true if merge operation changes the value
+/// of shader flags mask of Func; else false.
bool ModuleShaderFlags::mergeFunctionShaderFlags(const Function *Func,
ComputedShaderFlags SF) {
- std::pair<const Function *, ComputedShaderFlags> *Iter =
- llvm::lower_bound(FunctionFlags, Func, compareShaderFlagsInfo);
+ auto Iter = FunctionFlags.find(Func);
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
"Merge Shader Flags : No Shader Flags Mask exists for function");
- ComputedShaderFlags PreMergeSF = Iter->second;
- Iter->second.merge(SF);
- return (PreMergeSF != Iter->second);
+ return Iter->second.merge(SF);
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 158e4ea9033364..e3a2a26367c7da 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -70,12 +70,18 @@ struct ComputedShaderFlags {
return FeatureFlags;
}
- void merge(const ComputedShaderFlags CSF) {
+ /// Return merge result status viz., true if merge operation changes
+ /// shader flags mask; else false.
+ bool merge(const ComputedShaderFlags CSF) {
+ bool Changed = false;
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
+ Changed |= (FlagName ^ CSF.FlagName); \
+ FlagName |= CSF.FlagName;
+#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
+ Changed |= (FlagName ^ CSF.FlagName); \
FlagName |= CSF.FlagName;
-#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName |= CSF.FlagName;
#include "llvm/BinaryFormat/DXContainerConstants.def"
- return;
+ return Changed;
}
void print(raw_ostream &OS = dbgs()) const;
@@ -88,10 +94,10 @@ struct ModuleShaderFlags {
const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; }
private:
- /// Vector of sorted Function-Shader Flag mask pairs representing properties
- /// of each of the functions in the module. Shader Flags of each function
- /// represent both module-level and function-level flags
- SmallVector<std::pair<const Function *, ComputedShaderFlags>> FunctionFlags;
+ /// Map of Function-Shader Flag Mask pairs representing properties of each of
+ /// the functions in the module. Shader Flags of each function represent both
+ /// module-level and function-level flags
+ DenseMap<const Function *, ComputedShaderFlags> FunctionFlags;
/// Combined Shader Flag Mask of all functions of the module
ComputedShaderFlags CombinedSFMask{};
bool mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
>From 0e8e937d40ef76bc3b760581b1c0447f73f648fc Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Thu, 5 Dec 2024 11:55:11 -0500
Subject: [PATCH 4/5] Address PR feedback: Use CallGraph and scc_iterator to
discover functions for shader flag computation and propagation.
ComputedShaderFlags::merge() modified to not track and return
changed status as it does not need to provide such functionality.
Update test.
---
llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 103 ++++++++++++------
llvm/lib/Target/DirectX/DXILShaderFlags.h | 15 +--
.../propagate-function-flags-test.ll | 8 +-
3 files changed, 81 insertions(+), 45 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 3e1009124ceeca..19cbb35bc76f31 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -13,7 +13,9 @@
#include "DXILShaderFlags.h"
#include "DirectX.h"
+#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/CallGraph.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
@@ -47,39 +49,76 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF,
}
}
-void ModuleShaderFlags::initialize(const Module &M) {
- SmallVector<const Function *> WorkList;
- // Collect shader flags for each of the functions
- for (const auto &F : M.getFunctionList()) {
- if (F.isDeclaration())
+void ModuleShaderFlags::initialize(Module &M) {
+ CallGraph CG(M);
+
+ // Compute Shader Flags Mask for all functions using post-order visit of SCC
+ // of the call graph.
+ for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
+ ++SCCI) {
+ const std::vector<CallGraphNode *> &CurSCC = *SCCI;
+
+ // Union of shader masks of all functions in CurSCC
+ ComputedShaderFlags SCCSF;
+ for (CallGraphNode *CGN : CurSCC) {
+ Function *F = CGN->getFunction();
+ if (!F)
+ continue;
+
+ if (F->isDeclaration())
+ continue;
+
+ ComputedShaderFlags CSF;
+ for (const auto &BB : *F)
+ for (const auto &I : BB)
+ updateFunctionFlags(CSF, I);
+ // Insert shader flag mask for function F
+ FunctionFlags.insert({F, CSF});
+ // Update combined shader flags mask for all functions of the module
+ CombinedSFMask.merge(CSF);
+ // Update combined shader flags mask for all functions in this SCC
+ SCCSF.merge(CSF);
+ }
+
+ if (CurSCC.size() < 2)
continue;
- if (!F.user_empty())
- WorkList.push_back(&F);
-
- ComputedShaderFlags CSF;
- for (const auto &BB : F)
- for (const auto &I : BB)
- updateFunctionFlags(CSF, I);
- // Insert shader flag mask for function F
- FunctionFlags.insert({&F, CSF});
- // Update combined shader flags mask
- CombinedSFMask.merge(CSF);
+ // Shader flags mask of each of the functions in an SCC of the call graph is
+ // the union of all functions in the SCC. Update shader flags masks of
+ // functions in SCC accordingly.
+ for (CallGraphNode *CGN : CurSCC) {
+ Function *F = CGN->getFunction();
+ if (!F)
+ continue;
+ mergeFunctionShaderFlags(F, SCCSF);
+ }
}
- // Propagate shader flag mask of functions to their callers.
- while (!WorkList.empty()) {
- const Function *Func = WorkList.pop_back_val();
- if (Func->user_empty())
- continue;
-
- const ComputedShaderFlags &FuncSF = getFunctionFlags(Func);
- // Update mask of callers with that of Func
- for (const auto User : Func->users()) {
- if (const CallInst *CI = dyn_cast<CallInst>(User)) {
- const Function *Caller = CI->getParent()->getParent();
- if (mergeFunctionShaderFlags(Caller, FuncSF))
- WorkList.push_back(Caller);
+ // Propagate Shader Flag Masks to callers with another post-order call graph
+ // walk
+ for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
+ ++SCCI) {
+ const std::vector<CallGraphNode *> &CurSCC = *SCCI;
+ for (CallGraphNode *CGN : CurSCC) {
+ Function *F = CGN->getFunction();
+ if (!F)
+ continue;
+
+ if (F->isDeclaration() || F->user_empty())
+ continue;
+
+ const ComputedShaderFlags &FuncSF = getFunctionFlags(F);
+ // Update mask of callers with that of Func
+ for (const auto User : F->users()) {
+ if (const CallInst *CI = dyn_cast<CallInst>(User)) {
+ const Function *Caller = CI->getParent()->getParent();
+ // Do not need to update masks of callers in the current
+ // SCC, as the masks of all functions in the SCC are alreday
+ // the same. However, it is simpler to merge unconditionally
+ // instead of searching for membership of each Caller in the
+ // vector CurSCC to avoid merging.
+ mergeFunctionShaderFlags(Caller, FuncSF);
+ }
}
}
}
@@ -114,14 +153,12 @@ ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
/// Merge specified shader flags mask SF with current mask of the specified
/// function Func.
-/// Return merge result status viz., true if merge operation changes the value
-/// of shader flags mask of Func; else false.
-bool ModuleShaderFlags::mergeFunctionShaderFlags(const Function *Func,
+void ModuleShaderFlags::mergeFunctionShaderFlags(const Function *Func,
ComputedShaderFlags SF) {
auto Iter = FunctionFlags.find(Func);
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
"Merge Shader Flags : No Shader Flags Mask exists for function");
- return Iter->second.merge(SF);
+ Iter->second.merge(SF);
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index e3a2a26367c7da..b7837679da016b 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -70,18 +70,11 @@ struct ComputedShaderFlags {
return FeatureFlags;
}
- /// Return merge result status viz., true if merge operation changes
- /// shader flags mask; else false.
- bool merge(const ComputedShaderFlags CSF) {
- bool Changed = false;
+ void merge(const ComputedShaderFlags CSF) {
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
- Changed |= (FlagName ^ CSF.FlagName); \
- FlagName |= CSF.FlagName;
-#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
- Changed |= (FlagName ^ CSF.FlagName); \
FlagName |= CSF.FlagName;
+#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName |= CSF.FlagName;
#include "llvm/BinaryFormat/DXContainerConstants.def"
- return Changed;
}
void print(raw_ostream &OS = dbgs()) const;
@@ -89,7 +82,7 @@ struct ComputedShaderFlags {
};
struct ModuleShaderFlags {
- void initialize(const Module &);
+ void initialize(Module &);
const ComputedShaderFlags &getFunctionFlags(const Function *) const;
const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; }
@@ -100,7 +93,7 @@ struct ModuleShaderFlags {
DenseMap<const Function *, ComputedShaderFlags> FunctionFlags;
/// Combined Shader Flag Mask of all functions of the module
ComputedShaderFlags CombinedSFMask{};
- bool mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
+ void mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
};
class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
index 93d634c0384ae7..914ca021e95cd9 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
@@ -43,7 +43,7 @@ define double @call_n2(i64 noundef %0) local_unnamed_addr #0 {
3: ; preds = %1
%4 = add nuw nsw i64 %0, 1
%5 = uitofp i64 %4 to double
- %6 = tail call double @call_n1(double noundef %5)
+ %6 = tail call double @call_n11(double noundef %5)
br label %10
7: ; preds = %1
@@ -56,6 +56,12 @@ define double @call_n2(i64 noundef %0) local_unnamed_addr #0 {
ret double %11
}
+; CHECK: ; Function call_n11 : 0x00000044
+define double @call_n11(double noundef %0) local_unnamed_addr #1 {
+ %2 = tail call double @call_n1(double noundef %0)
+ ret double %2
+}
+
; CHECK: ; Function call_n1 : 0x00000044
define double @call_n1(double noundef %0) local_unnamed_addr #0 {
%2 = fcmp ugt double %0, 5.000000e+00
>From 52f95a7101e056f6ec06ce3917e07d536f005a2a Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 6 Dec 2024 15:31:05 -0500
Subject: [PATCH 5/5] Address PR feedback: Eliminate second loop to propagate
shader flags masks to callers
---
llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 53 +++++++--------------
1 file changed, 18 insertions(+), 35 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 19cbb35bc76f31..691428ac5b7dfd 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -72,52 +72,36 @@ void ModuleShaderFlags::initialize(Module &M) {
for (const auto &BB : *F)
for (const auto &I : BB)
updateFunctionFlags(CSF, I);
- // Insert shader flag mask for function F
- FunctionFlags.insert({F, CSF});
- // Update combined shader flags mask for all functions of the module
- CombinedSFMask.merge(CSF);
// Update combined shader flags mask for all functions in this SCC
SCCSF.merge(CSF);
}
- if (CurSCC.size() < 2)
- continue;
+ // Update combined shader flags mask for all functions of the module
+ CombinedSFMask.merge(SCCSF);
// Shader flags mask of each of the functions in an SCC of the call graph is
// the union of all functions in the SCC. Update shader flags masks of
- // functions in SCC accordingly.
- for (CallGraphNode *CGN : CurSCC) {
- Function *F = CGN->getFunction();
- if (!F)
- continue;
- mergeFunctionShaderFlags(F, SCCSF);
- }
- }
-
- // Propagate Shader Flag Masks to callers with another post-order call graph
- // walk
- for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
- ++SCCI) {
- const std::vector<CallGraphNode *> &CurSCC = *SCCI;
+ // functions in SCC accordingly. This is trivially true if SCC contains one
+ // function.
for (CallGraphNode *CGN : CurSCC) {
Function *F = CGN->getFunction();
if (!F)
continue;
-
- if (F->isDeclaration() || F->user_empty())
- continue;
-
- const ComputedShaderFlags &FuncSF = getFunctionFlags(F);
- // Update mask of callers with that of Func
+ // If F already has a shader flag mask associated as result of
+ // any of its callee's flags being propagated, merge SCCSF with
+ // existing flags. Else set its mask to SCCSF.
+ if (FunctionFlags.contains(F))
+ FunctionFlags[F].merge(SCCSF);
+ else
+ FunctionFlags[F] = SCCSF;
+ // Propagate Shader Flag Masks to callers of F
for (const auto User : F->users()) {
if (const CallInst *CI = dyn_cast<CallInst>(User)) {
const Function *Caller = CI->getParent()->getParent();
- // Do not need to update masks of callers in the current
- // SCC, as the masks of all functions in the SCC are alreday
- // the same. However, it is simpler to merge unconditionally
- // instead of searching for membership of each Caller in the
- // vector CurSCC to avoid merging.
- mergeFunctionShaderFlags(Caller, FuncSF);
+ if (FunctionFlags.contains(Caller))
+ FunctionFlags[Caller].merge(SCCSF);
+ else
+ FunctionFlags[Caller] = SCCSF;
}
}
}
@@ -155,10 +139,9 @@ ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
/// function Func.
void ModuleShaderFlags::mergeFunctionShaderFlags(const Function *Func,
ComputedShaderFlags SF) {
- auto Iter = FunctionFlags.find(Func);
- assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
+ assert(FunctionFlags.contains(Func) &&
"Merge Shader Flags : No Shader Flags Mask exists for function");
- Iter->second.merge(SF);
+ FunctionFlags[Func].merge(SF);
}
//===----------------------------------------------------------------------===//
More information about the llvm-commits
mailing list