[llvm] [DirectX] Propagate shader flags mask of callees to callers (PR #118306)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 2 07:02:22 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-directx
Author: S. Bharadwaj Yadavalli (bharadwajy)
<details>
<summary>Changes</summary>
Propagate the shader flags mask of callees to their callers.
Add a test to verify the propagation of shader flags.
---
Full diff: https://github.com/llvm/llvm-project/pull/118306.diff
4 Files Affected:
- (modified) llvm/lib/Target/DirectX/DXILShaderFlags.cpp (+39-4)
- (modified) llvm/lib/Target/DirectX/DXILShaderFlags.h (+2)
- (modified) llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll (+7)
- (added) llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll (+92)
``````````diff
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index d6917dce98abd5..f242204363cfe8 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -15,6 +15,7 @@
#include "DirectX.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
@@ -47,10 +48,14 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF,
}
void ModuleShaderFlags::initialize(const Module &M) {
+ SmallVector<const Function *> WorkList;
// Collect shader flags for each of the functions
for (const auto &F : M.getFunctionList()) {
if (F.isDeclaration())
continue;
+ if (!F.user_empty()) {
+ WorkList.push_back(&F);
+ }
ComputedShaderFlags CSF;
for (const auto &BB : F)
for (const auto &I : BB)
@@ -61,6 +66,21 @@ void ModuleShaderFlags::initialize(const Module &M) {
CombinedSFMask.merge(CSF);
}
llvm::sort(FunctionFlags);
+ // Propagate shader flag mask of functions to their callers.
+ while (!WorkList.empty()) {
+ const Function *Func = WorkList.pop_back_val();
+ if (!Func->user_empty()) {
+ ComputedShaderFlags FuncSF = getFunctionFlags(Func);
+ // Update mask of callers with that of Func
+ for (const auto User : Func->users()) {
+ if (const CallInst *CI = dyn_cast<CallInst>(User)) {
+ const Function *Caller = CI->getParent()->getParent();
+ if (mergeFunctionShaderFlags(Caller, FuncSF))
+ WorkList.push_back(Caller);
+ }
+ }
+ }
+ }
}
void ComputedShaderFlags::print(raw_ostream &OS) const {
@@ -81,16 +101,31 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
OS << ";\n";
}
-/// Return the shader flags mask of the specified function Func.
-const ComputedShaderFlags &
-ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
+auto ModuleShaderFlags::getFunctionShaderFlagInfo(const Function *Func) const {
const auto Iter = llvm::lower_bound(
FunctionFlags, Func,
[](const std::pair<const Function *, ComputedShaderFlags> FSM,
const Function *FindFunc) { return (FSM.first < FindFunc); });
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
"No Shader Flags Mask exists for function");
- return Iter->second;
+ return Iter;
+}
+
+/// Merge mask NewSF to that of Func, if different.
+/// Return true if mask of Func is changed, else false.
+bool ModuleShaderFlags::mergeFunctionShaderFlags(
+ const Function *Func, const ComputedShaderFlags NewSF) {
+ const auto FuncSFInfo = getFunctionShaderFlagInfo(Func);
+ if ((FuncSFInfo->second & NewSF) != NewSF) {
+ const_cast<ComputedShaderFlags &>(FuncSFInfo->second).merge(NewSF);
+ return true;
+ }
+ return false;
+}
+/// Return the shader flags mask of the specified function Func.
+const ComputedShaderFlags &
+ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
+ return getFunctionShaderFlagInfo(Func)->second;
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 2d60137f8b191c..8c581f243ca98b 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -95,6 +95,8 @@ struct ModuleShaderFlags {
SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
/// Combined Shader Flag Mask of all functions of the module
ComputedShaderFlags CombinedSFMask{};
+ auto getFunctionShaderFlagInfo(const Function *) const;
+ bool mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
};
class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
index 6332ef806a0d8f..8e5e61b42469ad 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
@@ -12,6 +12,13 @@ target triple = "dxil-pc-shadermodel6.7-library"
; CHECK-NEXT: ;
; CHECK-NEXT: ; Shader Flags for Module Functions
+;CHECK: ; Function top_level : 0x00000044
+define void @top_level() #0 {
+ call void @test_uitofp_i64(i64 noundef 5)
+ ret void
+}
+
+
; CHECK: ; Function test_fdiv_double : 0x00000044
define double @test_fdiv_double(double %a, double %b) #0 {
%res = fdiv double %a, %b
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
new file mode 100644
index 00000000000000..93d634c0384ae7
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
@@ -0,0 +1,92 @@
+; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; CHECK: ; Combined Shader Flags for Module
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ; Double-precision floating point
+; CHECK-NEXT: ; Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags for Module Functions
+
+; CHECK: ; Function call_n6 : 0x00000044
+define double @call_n6(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = sitofp i32 %0 to double
+ ret double %2
+}
+; CHECK: ; Function call_n4 : 0x00000044
+define double @call_n4(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = tail call double @call_n6(i32 noundef %0)
+ ret double %2
+}
+
+; CHECK: ; Function call_n7 : 0x00000044
+define double @call_n7(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = uitofp i32 %0 to double
+ ret double %2
+}
+
+; CHECK: ; Function call_n5 : 0x00000044
+define double @call_n5(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = tail call double @call_n7(i32 noundef %0)
+ ret double %2
+}
+
+; CHECK: ; Function call_n2 : 0x00000044
+define double @call_n2(i64 noundef %0) local_unnamed_addr #0 {
+ %2 = icmp ult i64 %0, 6
+ br i1 %2, label %3, label %7
+
+3: ; preds = %1
+ %4 = add nuw nsw i64 %0, 1
+ %5 = uitofp i64 %4 to double
+ %6 = tail call double @call_n1(double noundef %5)
+ br label %10
+
+7: ; preds = %1
+ %8 = trunc i64 %0 to i32
+ %9 = tail call double @call_n4(i32 noundef %8)
+ br label %10
+
+10: ; preds = %7, %3
+ %11 = phi double [ %6, %3 ], [ %9, %7 ]
+ ret double %11
+}
+
+; CHECK: ; Function call_n1 : 0x00000044
+define double @call_n1(double noundef %0) local_unnamed_addr #0 {
+ %2 = fcmp ugt double %0, 5.000000e+00
+ br i1 %2, label %6, label %3
+
+3: ; preds = %1
+ %4 = fptoui double %0 to i64
+ %5 = tail call double @call_n2(i64 noundef %4)
+ br label %9
+
+6: ; preds = %1
+ %7 = fptoui double %0 to i32
+ %8 = tail call double @call_n5(i32 noundef %7)
+ br label %9
+
+9: ; preds = %6, %3
+ %10 = phi double [ %5, %3 ], [ %8, %6 ]
+ ret double %10
+}
+
+; CHECK: ; Function call_n3 : 0x00000044
+define double @call_n3(double noundef %0) local_unnamed_addr #0 {
+ %2 = fdiv double %0, 3.000000e+00
+ ret double %2
+}
+
+; CHECK: ; Function main : 0x00000044
+define i32 @main() local_unnamed_addr #0 {
+ %1 = tail call double @call_n1(double noundef 1.000000e+00)
+ %2 = tail call double @call_n3(double noundef %1)
+ ret i32 0
+}
+
+attributes #0 = { convergent norecurse nounwind "hlsl.export"}
``````````
</details>
https://github.com/llvm/llvm-project/pull/118306
More information about the llvm-commits mailing list