[llvm] a4b7a2d - [DirectX] Propagate shader flags mask of callees to callers (#118306)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 14 10:18:21 PST 2025
Author: S. Bharadwaj Yadavalli
Date: 2025-01-14T13:18:16-05:00
New Revision: a4b7a2d021ca7371752f0e8180200ffd7b48ca70
URL: https://github.com/llvm/llvm-project/commit/a4b7a2d021ca7371752f0e8180200ffd7b48ca70
DIFF: https://github.com/llvm/llvm-project/commit/a4b7a2d021ca7371752f0e8180200ffd7b48ca70.diff
LOG: [DirectX] Propagate shader flags mask of callees to callers (#118306)
Propagate shader flags mask of callees to callers.
Add tests to verify propagation of shader flags.
Added:
llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
Modified:
llvm/lib/Target/DirectX/DXILShaderFlags.cpp
llvm/lib/Target/DirectX/DXILShaderFlags.h
llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 2edfc707ce6c79..b1ff975d4dae96 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -13,9 +13,12 @@
#include "DXILShaderFlags.h"
#include "DirectX.h"
-#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SCCIterator.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/DXILResource.h"
#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
@@ -27,15 +30,24 @@
using namespace llvm;
using namespace llvm::dxil;
-static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
- DXILResourceTypeMap &DRTM) {
+/// Update the shader flags mask based on the given instruction.
+/// \param CSF Shader flags mask to update.
+/// \param I Instruction to check.
+void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF,
+ const Instruction &I,
+ DXILResourceTypeMap &DRTM) {
if (!CSF.Doubles)
CSF.Doubles = I.getType()->isDoubleTy();
if (!CSF.Doubles) {
- for (Value *Op : I.operands())
- CSF.Doubles |= Op->getType()->isDoubleTy();
+ for (const Value *Op : I.operands()) {
+ if (Op->getType()->isDoubleTy()) {
+ CSF.Doubles = true;
+ break;
+ }
+ }
}
+
if (CSF.Doubles) {
switch (I.getOpcode()) {
case Instruction::FDiv:
@@ -43,8 +55,6 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
case Instruction::SIToFP:
case Instruction::FPToUI:
case Instruction::FPToSI:
- // TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma
- // https://github.com/llvm/llvm-project/issues/114554
CSF.DX11_1_DoubleExtensions = true;
break;
}
@@ -62,27 +72,65 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
}
}
}
+ // Handle call instructions
+ if (auto *CI = dyn_cast<CallInst>(&I)) {
+ const Function *CF = CI->getCalledFunction();
+ // Merge-in shader flags mask of the called function in the current module
+ if (FunctionFlags.contains(CF))
+ CSF.merge(FunctionFlags[CF]);
+
+ // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
+ // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
+ }
}
-void ModuleShaderFlags::initialize(const Module &M, DXILResourceTypeMap &DRTM) {
-
- // Collect shader flags for each of the functions
- for (const auto &F : M.getFunctionList()) {
- if (F.isDeclaration()) {
- assert(!F.getName().starts_with("dx.op.") &&
- "DXIL Shader Flag analysis should not be run post-lowering.");
- continue;
+/// Construct ModuleShaderFlags for module M
+void ModuleShaderFlags::initialize(Module &M, DXILResourceTypeMap &DRTM) {
+ CallGraph CG(M);
+
+ // Compute Shader Flags Mask for all functions using post-order visit of SCC
+ // of the call graph.
+ for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
+ ++SCCI) {
+ const std::vector<CallGraphNode *> &CurSCC = *SCCI;
+
+ // Union of shader masks of all functions in CurSCC
+ ComputedShaderFlags SCCSF;
+ // List of functions in CurSCC that are neither external nor declarations
+ // and hence whose flags are collected
+ SmallVector<Function *> CurSCCFuncs;
+ for (CallGraphNode *CGN : CurSCC) {
+ Function *F = CGN->getFunction();
+ if (!F)
+ continue;
+
+ if (F->isDeclaration()) {
+ assert(!F->getName().starts_with("dx.op.") &&
+ "DXIL Shader Flag analysis should not be run post-lowering.");
+ continue;
+ }
+
+ ComputedShaderFlags CSF;
+ for (const auto &BB : *F)
+ for (const auto &I : BB)
+ updateFunctionFlags(CSF, I, DRTM);
+ // Update combined shader flags mask for all functions in this SCC
+ SCCSF.merge(CSF);
+
+ CurSCCFuncs.push_back(F);
}
- ComputedShaderFlags CSF;
- for (const auto &BB : F)
- for (const auto &I : BB)
- updateFunctionFlags(CSF, I, DRTM);
- // Insert shader flag mask for function F
- FunctionFlags.push_back({&F, CSF});
- // Update combined shader flags mask
- CombinedSFMask.merge(CSF);
+
+ // Update combined shader flags mask for all functions of the module
+ CombinedSFMask.merge(SCCSF);
+
+ // Shader flags mask of each of the functions in an SCC of the call graph is
+ // the union of the masks of all functions in the SCC. Update shader flags
+ // masks of functions in CurSCC accordingly. This is trivially true if the
+ // SCC contains one function.
+ for (Function *F : CurSCCFuncs)
+ // Merge SCCSF with that of F
+ FunctionFlags[F].merge(SCCSF);
}
- llvm::sort(FunctionFlags);
}
void ComputedShaderFlags::print(raw_ostream &OS) const {
@@ -106,12 +154,9 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
/// Return the shader flags mask of the specified function Func.
const ComputedShaderFlags &
ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
- const auto Iter = llvm::lower_bound(
- FunctionFlags, Func,
- [](const std::pair<const Function *, ComputedShaderFlags> FSM,
- const Function *FindFunc) { return (FSM.first < FindFunc); });
+ auto Iter = FunctionFlags.find(Func);
assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
- "No Shader Flags Mask exists for function");
+ "Get Shader Flags : No Shader Flags Mask exists for function");
return Iter->second;
}
@@ -142,7 +187,7 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
for (const auto &F : M.getFunctionList()) {
if (F.isDeclaration())
continue;
- auto SFMask = FlagsInfo.getFunctionFlags(&F);
+ const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F);
OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
(uint64_t)(SFMask));
}
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 67ddab39d0f349..e6c6d56402c1a7 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -71,13 +71,11 @@ struct ComputedShaderFlags {
return FeatureFlags;
}
- void merge(const uint64_t IVal) {
+ void merge(const ComputedShaderFlags CSF) {
#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str) \
- FlagName |= (IVal & getMask(DxilModuleBit));
-#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
- FlagName |= (IVal & getMask(DxilModuleBit));
+ FlagName |= CSF.FlagName;
+#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName |= CSF.FlagName;
#include "llvm/BinaryFormat/DXContainerConstants.def"
- return;
}
void print(raw_ostream &OS = dbgs()) const;
@@ -85,17 +83,19 @@ struct ComputedShaderFlags {
};
struct ModuleShaderFlags {
- void initialize(const Module &, DXILResourceTypeMap &DRTM);
+ void initialize(Module &, DXILResourceTypeMap &DRTM);
const ComputedShaderFlags &getFunctionFlags(const Function *) const;
const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; }
private:
- /// Vector of sorted Function-Shader Flag mask pairs representing properties
- /// of each of the functions in the module. Shader Flags of each function
- /// represent both module-level and function-level flags
- SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
+ /// Map of Function-Shader Flag Mask pairs representing properties of each of
+ /// the functions in the module. Shader Flags of each function represent both
+ /// module-level and function-level flags
+ DenseMap<const Function *, ComputedShaderFlags> FunctionFlags;
/// Combined Shader Flag Mask of all functions of the module
ComputedShaderFlags CombinedSFMask{};
+ void updateFunctionFlags(ComputedShaderFlags &, const Instruction &,
+ DXILResourceTypeMap &);
};
class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
index 6332ef806a0d8f..d6df67626be5aa 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
@@ -12,6 +12,13 @@ target triple = "dxil-pc-shadermodel6.7-library"
; CHECK-NEXT: ;
; CHECK-NEXT: ; Shader Flags for Module Functions
+;CHECK: ; Function top_level : 0x00000044
+define double @top_level() #0 {
+ %r = call double @test_uitofp_i64(i64 5)
+ ret double %r
+}
+
+
; CHECK: ; Function test_fdiv_double : 0x00000044
define double @test_fdiv_double(double %a, double %b) #0 {
%res = fdiv double %a, %b
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
new file mode 100644
index 00000000000000..e7a2cf4d5b20f7
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
@@ -0,0 +1,167 @@
+; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; CHECK: ; Combined Shader Flags for Module
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ; Double-precision floating point
+; CHECK-NEXT: ; Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags for Module Functions
+
+; Call Graph of test source
+; main -> [get_fptoui_flag, get_sitofp_fdiv_flag]
+; get_fptoui_flag -> [get_sitofp_uitofp_flag, call_get_uitofp_flag]
+; get_sitofp_uitofp_flag -> [call_get_fptoui_flag, call_get_sitofp_flag]
+; call_get_fptoui_flag -> [get_fptoui_flag]
+; get_sitofp_fdiv_flag -> [get_no_flags, get_all_doubles_flags]
+; get_all_doubles_flags -> [call_get_sitofp_fdiv_flag]
+; call_get_sitofp_fdiv_flag -> [get_sitofp_fdiv_flag]
+; call_get_sitofp_flag -> [get_sitofp_flag]
+; call_get_uitofp_flag -> [get_uitofp_flag]
+; get_sitofp_flag -> []
+; get_uitofp_flag -> []
+; get_no_flags -> []
+;
+; Strongly Connected Components in the CG
+; [get_fptoui_flag, get_sitofp_uitofp_flag, call_get_fptoui_flag]
+; [get_sitofp_fdiv_flag, get_all_doubles_flags, call_get_sitofp_fdiv_flag]
+
+;
+; CHECK: ; Function get_sitofp_flag : 0x00000044
+define double @get_sitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = sitofp i32 %0 to double
+ ret double %2
+}
+
+; CHECK: ; Function call_get_sitofp_flag : 0x00000044
+define double @call_get_sitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = tail call double @get_sitofp_flag(i32 noundef %0)
+ ret double %2
+}
+
+; CHECK: ; Function get_uitofp_flag : 0x00000044
+define double @get_uitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = uitofp i32 %0 to double
+ ret double %2
+}
+
+; CHECK: ; Function call_get_uitofp_flag : 0x00000044
+define double @call_get_uitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = tail call double @get_uitofp_flag(i32 noundef %0)
+ ret double %2
+}
+
+; CHECK: ; Function call_get_fptoui_flag : 0x00000044
+define double @call_get_fptoui_flag(double noundef %0) local_unnamed_addr #0 {
+ %2 = tail call double @get_fptoui_flag(double noundef %0)
+ ret double %2
+}
+
+; CHECK: ; Function get_fptoui_flag : 0x00000044
+define double @get_fptoui_flag(double noundef %0) local_unnamed_addr #0 {
+ %2 = fcmp ugt double %0, 5.000000e+00
+ br i1 %2, label %6, label %3
+
+3: ; preds = %1
+ %4 = fptoui double %0 to i64
+ %5 = tail call double @get_sitofp_uitofp_flag(i64 noundef %4)
+ br label %9
+
+6: ; preds = %1
+ %7 = fptoui double %0 to i32
+ %8 = tail call double @call_get_uitofp_flag(i32 noundef %7)
+ br label %9
+
+9: ; preds = %6, %3
+ %10 = phi double [ %5, %3 ], [ %8, %6 ]
+ ret double %10
+}
+
+; CHECK: ; Function get_sitofp_uitofp_flag : 0x00000044
+define double @get_sitofp_uitofp_flag(i64 noundef %0) local_unnamed_addr #0 {
+ %2 = icmp ult i64 %0, 6
+ br i1 %2, label %3, label %7
+
+3: ; preds = %1
+ %4 = add nuw nsw i64 %0, 1
+ %5 = uitofp i64 %4 to double
+ %6 = tail call double @call_get_fptoui_flag(double noundef %5)
+ br label %10
+
+7: ; preds = %1
+ %8 = trunc i64 %0 to i32
+ %9 = tail call double @call_get_sitofp_flag(i32 noundef %8)
+ br label %10
+
+10: ; preds = %7, %3
+ %11 = phi double [ %6, %3 ], [ %9, %7 ]
+ ret double %11
+}
+
+; CHECK: ; Function get_no_flags : 0x00000000
+define i32 @get_no_flags(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = mul nsw i32 %0, %0
+ ret i32 %2
+}
+
+; CHECK: ; Function call_get_sitofp_fdiv_flag : 0x00000044
+define i32 @call_get_sitofp_fdiv_flag(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = icmp eq i32 %0, 0
+ br i1 %2, label %5, label %3
+
+3: ; preds = %1
+ %4 = mul nsw i32 %0, %0
+ br label %7
+
+5: ; preds = %1
+ %6 = tail call double @get_sitofp_fdiv_flag(i32 noundef 0)
+ br label %7
+
+7: ; preds = %5, %3
+ %8 = phi i32 [ %4, %3 ], [ 0, %5 ]
+ ret i32 %8
+}
+
+; CHECK: ; Function get_sitofp_fdiv_flag : 0x00000044
+define double @get_sitofp_fdiv_flag(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = icmp sgt i32 %0, 5
+ br i1 %2, label %3, label %6
+
+3: ; preds = %1
+ %4 = tail call i32 @get_no_flags(i32 noundef %0)
+ %5 = sitofp i32 %4 to double
+ br label %9
+
+6: ; preds = %1
+ %7 = tail call double @get_all_doubles_flags(i32 noundef %0)
+ %8 = fdiv double %7, 3.000000e+00
+ br label %9
+
+9: ; preds = %6, %3
+ %10 = phi double [ %5, %3 ], [ %8, %6 ]
+ ret double %10
+}
+
+; CHECK: ; Function get_all_doubles_flags : 0x00000044
+define double @get_all_doubles_flags(i32 noundef %0) local_unnamed_addr #0 {
+ %2 = tail call i32 @call_get_sitofp_fdiv_flag(i32 noundef %0)
+ %3 = icmp eq i32 %2, 0
+ %4 = select i1 %3, double 1.000000e+01, double 1.000000e+02
+ ret double %4
+}
+
+; CHECK: ; Function main : 0x00000044
+define i32 @main() local_unnamed_addr #0 {
+ %1 = tail call double @get_fptoui_flag(double noundef 1.000000e+00)
+ %2 = tail call double @get_sitofp_fdiv_flag(i32 noundef 4)
+ %3 = fadd double %1, %2
+ %4 = fcmp ogt double %3, 0.000000e+00
+ %5 = zext i1 %4 to i32
+ ret i32 %5
+}
+
+attributes #0 = { convergent norecurse nounwind "hlsl.export"}
More information about the llvm-commits
mailing list