[llvm] [DirectX] Propagate shader flags mask of callees to callers (PR #118306)
S. Bharadwaj Yadavalli via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 4 16:52:02 PST 2024
https://github.com/bharadwajy updated https://github.com/llvm/llvm-project/pull/118306
>From 6a4fef8b1c8ad78db63016ddc927bbbebeea04c5 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 27 Nov 2024 11:25:34 -0500
Subject: [PATCH 1/3] Propagate shader flags mask of callees to callers
Add test to verify propagation of shader flags
---
llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 43 ++++++++-
llvm/lib/Target/DirectX/DXILShaderFlags.h | 2 +
.../DirectX/ShaderFlags/double-extensions.ll | 7 ++
.../propagate-function-flags-test.ll | 92 +++++++++++++++++++
4 files changed, 140 insertions(+), 4 deletions(-)
create mode 100644 llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index d6917dce98abd5..f242204363cfe8 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -15,6 +15,7 @@
#include "DirectX.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
@@ -47,10 +48,14 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF,
}
void ModuleShaderFlags::initialize(const Module &M) {
+ SmallVector<const Function *> WorkList;
// Collect shader flags for each of the functions
for (const auto &F : M.getFunctionList()) {
if (F.isDeclaration())
continue;
+ if (!F.user_empty()) {
+ WorkList.push_back(&F);
+ }
ComputedShaderFlags CSF;
for (const auto &BB : F)
for (const auto &I : BB)
@@ -61,6 +66,21 @@ void ModuleShaderFlags::initialize(const Module &M) {
CombinedSFMask.merge(CSF);
}
llvm::sort(FunctionFlags);
+ // Propagate shader flag mask of functions to their callers.
+ while (!WorkList.empty()) {
+ const Function *Func = WorkList.pop_back_val();
+ if (!Func->user_empty()) {
+ ComputedShaderFlags FuncSF = getFunctionFlags(Func);
+ // Update mask of callers with that of Func
+ for (const auto User : Func->users()) {
+ if (const CallInst *CI = dyn_cast<CallInst>(User)) {
+ const Function *Caller = CI->getParent()->getParent();
+ if (mergeFunctionShaderFlags(Caller, FuncSF))
+ WorkList.push_back(Caller);
+ }
+ }
+ }
+ }
}
void ComputedShaderFlags::print(raw_ostream &OS) const {
@@ -81,16 +101,31 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
OS << ";\n";
}
-/// Return the shader flags mask of the specified function Func.
-const ComputedShaderFlags &
-ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
+auto ModuleShaderFlags::getFunctionShaderFlagInfo(const Function *Func) const {
const auto Iter = llvm::lower_bound(
FunctionFlags, Func,
[](const std::pair<const Function *, ComputedShaderFlags> FSM,
const Function *FindFunc) { return (FSM.first < FindFunc); });
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
"No Shader Flags Mask exists for function");
- return Iter->second;
+ return Iter;
+}
+
+/// Merge mask NewSF to that of Func, if different.
+/// Return true if mask of Func is changed, else false.
+bool ModuleShaderFlags::mergeFunctionShaderFlags(
+ const Function *Func, const ComputedShaderFlags NewSF) {
+ const auto FuncSFInfo = getFunctionShaderFlagInfo(Func);
+ if ((FuncSFInfo->second & NewSF) != NewSF) {
+ const_cast<ComputedShaderFlags &>(FuncSFInfo->second).merge(NewSF);
+ return true;
+ }
+ return false;
+}
+/// Return the shader flags mask of the specified function Func.
+const ComputedShaderFlags &
+ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
+ return getFunctionShaderFlagInfo(Func)->second;
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 2d60137f8b191c..8c581f243ca98b 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -95,6 +95,8 @@ struct ModuleShaderFlags {
SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
/// Combined Shader Flag Mask of all functions of the module
ComputedShaderFlags CombinedSFMask{};
+ auto getFunctionShaderFlagInfo(const Function *) const;
+ bool mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
};
class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
index 6332ef806a0d8f..8e5e61b42469ad 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
@@ -12,6 +12,13 @@ target triple = "dxil-pc-shadermodel6.7-library"
; CHECK-NEXT: ;
; CHECK-NEXT: ; Shader Flags for Module Functions
+;CHECK: ; Function top_level : 0x00000044
+define void @top_level() #0 {
+ call void @test_uitofp_i64(i64 noundef 5)
+ ret void
+}
+
+
; CHECK: ; Function test_fdiv_double : 0x00000044
define double @test_fdiv_double(double %a, double %b) #0 {
%res = fdiv double %a, %b
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
new file mode 100644
index 00000000000000..93d634c0384ae7
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
@@ -0,0 +1,92 @@
+; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; CHECK: ; Combined Shader Flags for Module
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ; Double-precision floating point
+; CHECK-NEXT: ; Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags for Module Functions
+
+; CHECK: ; Function call_n6 : 0x00000044
+define double @call_n6(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = sitofp i32 %0 to double
+ ret double %2
+}
+; CHECK: ; Function call_n4 : 0x00000044
+define double @call_n4(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = tail call double @call_n6(i32 noundef %0)
+ ret double %2
+}
+
+; CHECK: ; Function call_n7 : 0x00000044
+define double @call_n7(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = uitofp i32 %0 to double
+ ret double %2
+}
+
+; CHECK: ; Function call_n5 : 0x00000044
+define double @call_n5(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = tail call double @call_n7(i32 noundef %0)
+ ret double %2
+}
+
+; CHECK: ; Function call_n2 : 0x00000044
+define double @call_n2(i64 noundef %0) local_unnamed_addr #0 {
+ %2 = icmp ult i64 %0, 6
+ br i1 %2, label %3, label %7
+
+3: ; preds = %1
+ %4 = add nuw nsw i64 %0, 1
+ %5 = uitofp i64 %4 to double
+ %6 = tail call double @call_n1(double noundef %5)
+ br label %10
+
+7: ; preds = %1
+ %8 = trunc i64 %0 to i32
+ %9 = tail call double @call_n4(i32 noundef %8)
+ br label %10
+
+10: ; preds = %7, %3
+ %11 = phi double [ %6, %3 ], [ %9, %7 ]
+ ret double %11
+}
+
+; CHECK: ; Function call_n1 : 0x00000044
+define double @call_n1(double noundef %0) local_unnamed_addr #0 {
+ %2 = fcmp ugt double %0, 5.000000e+00
+ br i1 %2, label %6, label %3
+
+3: ; preds = %1
+ %4 = fptoui double %0 to i64
+ %5 = tail call double @call_n2(i64 noundef %4)
+ br label %9
+
+6: ; preds = %1
+ %7 = fptoui double %0 to i32
+ %8 = tail call double @call_n5(i32 noundef %7)
+ br label %9
+
+9: ; preds = %6, %3
+ %10 = phi double [ %5, %3 ], [ %8, %6 ]
+ ret double %10
+}
+
+; CHECK: ; Function call_n3 : 0x00000044
+define double @call_n3(double noundef %0) local_unnamed_addr #0 {
+ %2 = fdiv double %0, 3.000000e+00
+ ret double %2
+}
+
+; CHECK: ; Function main : 0x00000044
+define i32 @main() local_unnamed_addr #0 {
+ %1 = tail call double @call_n1(double noundef 1.000000e+00)
+ %2 = tail call double @call_n3(double noundef %1)
+ ret i32 0
+}
+
+attributes #0 = { convergent norecurse nounwind "hlsl.export"}
>From a360efc5d0a90b4bf36384fa779df01c07bba1fd Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Tue, 3 Dec 2024 19:52:07 -0500
Subject: [PATCH 2/3] Address PR feedback
---
llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 48 +++++++++++----------
llvm/lib/Target/DirectX/DXILShaderFlags.h | 10 ++---
2 files changed, 30 insertions(+), 28 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index f242204363cfe8..44fa5a2bb060b6 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -70,7 +70,7 @@ void ModuleShaderFlags::initialize(const Module &M) {
while (!WorkList.empty()) {
const Function *Func = WorkList.pop_back_val();
if (!Func->user_empty()) {
- ComputedShaderFlags FuncSF = getFunctionFlags(Func);
+ const ComputedShaderFlags &FuncSF = getFunctionFlags(Func);
// Update mask of callers with that of Func
for (const auto User : Func->users()) {
if (const CallInst *CI = dyn_cast<CallInst>(User)) {
@@ -101,31 +101,35 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
OS << ";\n";
}
-auto ModuleShaderFlags::getFunctionShaderFlagInfo(const Function *Func) const {
- const auto Iter = llvm::lower_bound(
- FunctionFlags, Func,
- [](const std::pair<const Function *, ComputedShaderFlags> FSM,
- const Function *FindFunc) { return (FSM.first < FindFunc); });
- assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
- "No Shader Flags Mask exists for function");
- return Iter;
+static bool compareShaderFlagsInfo(
+ const std::pair<const Function *, ComputedShaderFlags> FSM,
+ const Function *FindFunc) {
+ return (FSM.first < FindFunc);
}
-/// Merge mask NewSF to that of Func, if different.
-/// Return true if mask of Func is changed, else false.
-bool ModuleShaderFlags::mergeFunctionShaderFlags(
- const Function *Func, const ComputedShaderFlags NewSF) {
- const auto FuncSFInfo = getFunctionShaderFlagInfo(Func);
- if ((FuncSFInfo->second & NewSF) != NewSF) {
- const_cast<ComputedShaderFlags &>(FuncSFInfo->second).merge(NewSF);
- return true;
- }
- return false;
-}
/// Return the shader flags mask of the specified function Func.
const ComputedShaderFlags &
ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
- return getFunctionShaderFlagInfo(Func)->second;
+ const std::pair<const Function *, ComputedShaderFlags> *Iter =
+ llvm::lower_bound(FunctionFlags, Func, compareShaderFlagsInfo);
+ assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
+ "Get Shader Flags : No Shader Flags Mask exists for function");
+ return Iter->second;
+}
+
+/// Merge specified shader flags mask SF with current mask of the specified
+/// function Func.
+/// Return true if merge operation changes the value of shader flags mask of
+/// Func; else false.
+bool ModuleShaderFlags::mergeFunctionShaderFlags(const Function *Func,
+ ComputedShaderFlags SF) {
+ std::pair<const Function *, ComputedShaderFlags> *Iter =
+ llvm::lower_bound(FunctionFlags, Func, compareShaderFlagsInfo);
+ assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
+ "Merge Shader Flags : No Shader Flags Mask exists for function");
+ ComputedShaderFlags PreMergeSF = Iter->second;
+ Iter->second.merge(SF);
+ return (PreMergeSF != Iter->second);
}
//===----------------------------------------------------------------------===//
@@ -152,7 +156,7 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
for (const auto &F : M.getFunctionList()) {
if (F.isDeclaration())
continue;
- auto SFMask = FlagsInfo.getFunctionFlags(&F);
+ const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F);
OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
(uint64_t)(SFMask));
}
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 8c581f243ca98b..158e4ea9033364 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -70,11 +70,10 @@ struct ComputedShaderFlags {
return FeatureFlags;
}
- void merge(const uint64_t IVal) {
+ void merge(const ComputedShaderFlags CSF) {
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
- FlagName |= (IVal & getMask(DxilModuleBit));
-#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
- FlagName |= (IVal & getMask(DxilModuleBit));
+ FlagName |= CSF.FlagName;
+#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName |= CSF.FlagName;
#include "llvm/BinaryFormat/DXContainerConstants.def"
return;
}
@@ -92,10 +91,9 @@ struct ModuleShaderFlags {
/// Vector of sorted Function-Shader Flag mask pairs representing properties
/// of each of the functions in the module. Shader Flags of each function
/// represent both module-level and function-level flags
- SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
+ SmallVector<std::pair<const Function *, ComputedShaderFlags>> FunctionFlags;
/// Combined Shader Flag Mask of all functions of the module
ComputedShaderFlags CombinedSFMask{};
- auto getFunctionShaderFlagInfo(const Function *) const;
bool mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
};
>From b7cb3cd12a5816f55037e55cb981e10cdc462fbf Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 4 Dec 2024 15:07:11 -0500
Subject: [PATCH 3/3] Address PR feedback - Use DenseMap instead of SmallVector
for Function-Shader Flags Mask map - Return true from
ComputedShaderFlags::merge() if value is modified - Style nits
---
llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 48 +++++++++------------
llvm/lib/Target/DirectX/DXILShaderFlags.h | 20 ++++++---
2 files changed, 33 insertions(+), 35 deletions(-)
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 44fa5a2bb060b6..3e1009124ceeca 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -53,31 +53,33 @@ void ModuleShaderFlags::initialize(const Module &M) {
for (const auto &F : M.getFunctionList()) {
if (F.isDeclaration())
continue;
- if (!F.user_empty()) {
+
+ if (!F.user_empty())
WorkList.push_back(&F);
- }
+
ComputedShaderFlags CSF;
for (const auto &BB : F)
for (const auto &I : BB)
updateFunctionFlags(CSF, I);
// Insert shader flag mask for function F
- FunctionFlags.push_back({&F, CSF});
+ FunctionFlags.insert({&F, CSF});
// Update combined shader flags mask
CombinedSFMask.merge(CSF);
}
- llvm::sort(FunctionFlags);
+
// Propagate shader flag mask of functions to their callers.
while (!WorkList.empty()) {
const Function *Func = WorkList.pop_back_val();
- if (!Func->user_empty()) {
- const ComputedShaderFlags &FuncSF = getFunctionFlags(Func);
- // Update mask of callers with that of Func
- for (const auto User : Func->users()) {
- if (const CallInst *CI = dyn_cast<CallInst>(User)) {
- const Function *Caller = CI->getParent()->getParent();
- if (mergeFunctionShaderFlags(Caller, FuncSF))
- WorkList.push_back(Caller);
- }
+ if (Func->user_empty())
+ continue;
+
+ const ComputedShaderFlags &FuncSF = getFunctionFlags(Func);
+ // Update mask of callers with that of Func
+ for (const auto User : Func->users()) {
+ if (const CallInst *CI = dyn_cast<CallInst>(User)) {
+ const Function *Caller = CI->getParent()->getParent();
+ if (mergeFunctionShaderFlags(Caller, FuncSF))
+ WorkList.push_back(Caller);
}
}
}
@@ -101,17 +103,10 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
OS << ";\n";
}
-static bool compareShaderFlagsInfo(
- const std::pair<const Function *, ComputedShaderFlags> FSM,
- const Function *FindFunc) {
- return (FSM.first < FindFunc);
-}
-
/// Return the shader flags mask of the specified function Func.
const ComputedShaderFlags &
ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
- const std::pair<const Function *, ComputedShaderFlags> *Iter =
- llvm::lower_bound(FunctionFlags, Func, compareShaderFlagsInfo);
+ auto Iter = FunctionFlags.find(Func);
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
"Get Shader Flags : No Shader Flags Mask exists for function");
return Iter->second;
@@ -119,17 +114,14 @@ ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
/// Merge specified shader flags mask SF with current mask of the specified
/// function Func.
-/// Return true if merge operation changes the value of shader flags mask of
-/// Func; else false.
+/// Return merge result status viz., true if merge operation changes the value
+/// of shader flags mask of Func; else false.
bool ModuleShaderFlags::mergeFunctionShaderFlags(const Function *Func,
ComputedShaderFlags SF) {
- std::pair<const Function *, ComputedShaderFlags> *Iter =
- llvm::lower_bound(FunctionFlags, Func, compareShaderFlagsInfo);
+ auto Iter = FunctionFlags.find(Func);
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
"Merge Shader Flags : No Shader Flags Mask exists for function");
- ComputedShaderFlags PreMergeSF = Iter->second;
- Iter->second.merge(SF);
- return (PreMergeSF != Iter->second);
+ return Iter->second.merge(SF);
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 158e4ea9033364..e3a2a26367c7da 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -70,12 +70,18 @@ struct ComputedShaderFlags {
return FeatureFlags;
}
- void merge(const ComputedShaderFlags CSF) {
+ /// Return merge result status viz., true if merge operation changes
+ /// shader flags mask; else false.
+ bool merge(const ComputedShaderFlags CSF) {
+ bool Changed = false;
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
+ Changed |= (FlagName ^ CSF.FlagName); \
+ FlagName |= CSF.FlagName;
+#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
+ Changed |= (FlagName ^ CSF.FlagName); \
FlagName |= CSF.FlagName;
-#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName |= CSF.FlagName;
#include "llvm/BinaryFormat/DXContainerConstants.def"
- return;
+ return Changed;
}
void print(raw_ostream &OS = dbgs()) const;
@@ -88,10 +94,10 @@ struct ModuleShaderFlags {
const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; }
private:
- /// Vector of sorted Function-Shader Flag mask pairs representing properties
- /// of each of the functions in the module. Shader Flags of each function
- /// represent both module-level and function-level flags
- SmallVector<std::pair<const Function *, ComputedShaderFlags>> FunctionFlags;
+ /// Map of Function-Shader Flag Mask pairs representing properties of each of
+ /// the functions in the module. Shader Flags of each function represent both
+ /// module-level and function-level flags
+ DenseMap<const Function *, ComputedShaderFlags> FunctionFlags;
/// Combined Shader Flag Mask of all functions of the module
ComputedShaderFlags CombinedSFMask{};
bool mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
More information about the llvm-commits
mailing list