[llvm] [DirectX] Propagate shader flags mask of callees to callers (PR #118306)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 14 09:14:19 PST 2025


https://github.com/bharadwajy updated https://github.com/llvm/llvm-project/pull/118306

>From b6dfe53cf7476cf64f2edf21361e43e4a3f3a9ff Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 27 Nov 2024 11:25:34 -0500
Subject: [PATCH 1/3] [Rebase] Propagate shader flags mask of callees to
 callers

Add tests to verify propagation of shader flags from callees to callers
---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp   | 104 ++++++++---
 llvm/lib/Target/DirectX/DXILShaderFlags.h     |  20 +--
 .../DirectX/ShaderFlags/double-extensions.ll  |   7 +
 .../propagate-function-flags-test.ll          | 167 ++++++++++++++++++
 4 files changed, 259 insertions(+), 39 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 2edfc707ce6c79..e956189f8ecd4e 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -13,9 +13,13 @@
 
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
+#include "llvm/ADT/SCCIterator.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/CallGraph.h"
 #include "llvm/Analysis/DXILResource.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsDirectX.h"
@@ -27,15 +31,24 @@
 using namespace llvm;
 using namespace llvm::dxil;
 
-static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
-                                DXILResourceTypeMap &DRTM) {
+/// Update CSF based on I, merging in the mask of a callee already analyzed.
+/// \param CSF Shader flags mask to update.
+/// \param I Instruction to check. \param DRTM Resource type map.
+void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF,
+                                            const Instruction &I,
+                                            DXILResourceTypeMap &DRTM) {
   if (!CSF.Doubles)
     CSF.Doubles = I.getType()->isDoubleTy();
 
   if (!CSF.Doubles) {
-    for (Value *Op : I.operands())
-      CSF.Doubles |= Op->getType()->isDoubleTy();
+    for (const Value *Op : I.operands()) {
+      if (Op->getType()->isDoubleTy()) {
+        CSF.Doubles = true;
+        break;
+      }
+    }
   }
+
   if (CSF.Doubles) {
     switch (I.getOpcode()) {
     case Instruction::FDiv:
@@ -43,8 +56,6 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
     case Instruction::SIToFP:
     case Instruction::FPToUI:
     case Instruction::FPToSI:
-      // TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma
-      // https://github.com/llvm/llvm-project/issues/114554
       CSF.DX11_1_DoubleExtensions = true;
       break;
     }
@@ -62,27 +73,80 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
     }
     }
   }
+  // Handle call instructions
+  if (auto *CI = dyn_cast<CallInst>(&I)) {
+    const Function *CF = CI->getCalledFunction();
+    // Merge-in shader flags mask of the called function in the current module
+    if (FunctionFlags.contains(CF)) {
+      CSF.merge(FunctionFlags[CF]);
+    }
+    // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
+    // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
+  }
 }

-void ModuleShaderFlags::initialize(const Module &M, DXILResourceTypeMap &DRTM) {
-
-  // Collect shader flags for each of the functions
-  for (const auto &F : M.getFunctionList()) {
-    if (F.isDeclaration()) {
-      assert(!F.getName().starts_with("dx.op.") &&
-             "DXIL Shader Flag analysis should not be run post-lowering.");
-      continue;
-    }
-    ComputedShaderFlags CSF;
-    for (const auto &BB : F)
-      for (const auto &I : BB)
-        updateFunctionFlags(CSF, I, DRTM);
-    // Insert shader flag mask for function F
-    FunctionFlags.push_back({&F, CSF});
-    // Update combined shader flags mask
-    CombinedSFMask.merge(CSF);
-  }
-  llvm::sort(FunctionFlags);
+/// Construct ModuleShaderFlags for module M.
+void ModuleShaderFlags::initialize(Module &M, DXILResourceTypeMap &DRTM) {
+  CallGraph CG(M);
+
+  // Compute the shader flags mask of the functions of one SCC of the call
+  // graph. SCCs are visited in post-order, so the masks of all callees outside
+  // the SCC are already known and get merged in by updateFunctionFlags().
+  auto ProcessSCC = [&](const std::vector<CallGraphNode *> &CurSCC) {
+    // Union of shader masks of all functions in CurSCC
+    ComputedShaderFlags SCCSF;
+    // List of functions in CurSCC that are neither external nor declarations
+    // and hence whose flags are collected
+    SmallVector<Function *> CurSCCFuncs;
+    for (CallGraphNode *CGN : CurSCC) {
+      Function *F = CGN->getFunction();
+      // Skip external nodes and functions whose SCC was already processed.
+      if (!F || FunctionFlags.contains(F))
+        continue;
+
+      if (F->isDeclaration()) {
+        assert(!F->getName().starts_with("dx.op.") &&
+               "DXIL Shader Flag analysis should not be run post-lowering.");
+        continue;
+      }
+
+      ComputedShaderFlags CSF;
+      for (const auto &BB : *F)
+        for (const auto &I : BB)
+          updateFunctionFlags(CSF, I, DRTM);
+      // Update combined shader flags mask for all functions in this SCC
+      SCCSF.merge(CSF);
+
+      CurSCCFuncs.push_back(F);
+    }
+
+    // Update combined shader flags mask for all functions of the module
+    CombinedSFMask.merge(SCCSF);
+
+    // Shader flags mask of each of the functions in an SCC of the call graph
+    // is the union of all functions in the SCC. Update shader flags masks of
+    // functions in CurSCC accordingly. This is trivially true if SCC contains
+    // one function.
+    for (Function *F : CurSCCFuncs)
+      FunctionFlags[F].merge(SCCSF);
+  };
+
+  // Visit, in post-order, the SCCs reachable from the external calling node.
+  for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
+       ++SCCI)
+    ProcessSCC(*SCCI);
+
+  // A function with local linkage that is neither called nor address-taken is
+  // not reachable from the external calling node, and neither are functions
+  // called only from it. Visit the SCCs reachable from each such function so
+  // that every defined function has a mask for getFunctionFlags() to return.
+  for (Function &F : M) {
+    if (F.isDeclaration() || FunctionFlags.contains(&F))
+      continue;
+    for (scc_iterator<CallGraphNode *> SCCI = scc_begin(CG[&F]);
+         !SCCI.isAtEnd(); ++SCCI)
+      ProcessSCC(*SCCI);
+  }
 }

 void ComputedShaderFlags::print(raw_ostream &OS) const {
@@ -106,12 +155,9 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
 /// Return the shader flags mask of the specified function Func.
 const ComputedShaderFlags &
 ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
-  const auto Iter = llvm::lower_bound(
-      FunctionFlags, Func,
-      [](const std::pair<const Function *, ComputedShaderFlags> FSM,
-         const Function *FindFunc) { return (FSM.first < FindFunc); });
-  assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
-         "No Shader Flags Mask exists for function");
+  auto Iter = FunctionFlags.find(Func);
+  assert(Iter != FunctionFlags.end() &&
+         "Get Shader Flags : No Shader Flags Mask exists for function");
   return Iter->second;
 }
 
@@ -142,7 +188,7 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
   for (const auto &F : M.getFunctionList()) {
     if (F.isDeclaration())
       continue;
-    auto SFMask = FlagsInfo.getFunctionFlags(&F);
+    const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F);
     OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
                   (uint64_t)(SFMask));
   }
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 67ddab39d0f349..e6c6d56402c1a7 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -71,13 +71,11 @@ struct ComputedShaderFlags {
     return FeatureFlags;
   }
 
-  void merge(const uint64_t IVal) {
+  void merge(const ComputedShaderFlags &CSF) {
 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str)          \
-  FlagName |= (IVal & getMask(DxilModuleBit));
-#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str)                         \
-  FlagName |= (IVal & getMask(DxilModuleBit));
+  FlagName |= CSF.FlagName;
+#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName |= CSF.FlagName;
 #include "llvm/BinaryFormat/DXContainerConstants.def"
-    return;
   }
 
   void print(raw_ostream &OS = dbgs()) const;
@@ -85,17 +83,19 @@ struct ComputedShaderFlags {
 };
 
 struct ModuleShaderFlags {
-  void initialize(const Module &, DXILResourceTypeMap &DRTM);
+  void initialize(Module &, DXILResourceTypeMap &DRTM);
   const ComputedShaderFlags &getFunctionFlags(const Function *) const;
   const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; }
 
 private:
-  /// Vector of sorted Function-Shader Flag mask pairs representing properties
-  /// of each of the functions in the module. Shader Flags of each function
-  /// represent both module-level and function-level flags
-  SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
+  /// Map of Function-Shader Flag Mask pairs representing properties of each of
+  /// the functions in the module. The mask of a function includes the masks of
+  /// the functions it calls and covers module-level and function-level flags.
+  DenseMap<const Function *, ComputedShaderFlags> FunctionFlags;
   /// Combined Shader Flag Mask of all functions of the module
   ComputedShaderFlags CombinedSFMask{};
+  void updateFunctionFlags(ComputedShaderFlags &, const Instruction &,
+                           DXILResourceTypeMap &);
 };
 
 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
 
 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
index 6332ef806a0d8f..d6df67626be5aa 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
@@ -12,6 +12,12 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Shader Flags for Module Functions
 
+; CHECK: ; Function top_level : 0x00000044
+define double @top_level() #0 {
+  %r = call double @test_uitofp_i64(i64 5)
+  ret double %r
+}
+
 ; CHECK: ; Function test_fdiv_double : 0x00000044
 define double @test_fdiv_double(double %a, double %b) #0 {
   %res = fdiv double %a, %b
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
new file mode 100644
index 00000000000000..e7a2cf4d5b20f7
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
@@ -0,0 +1,167 @@
+; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; CHECK: ; Combined Shader Flags for Module
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ;       Double-precision floating point
+; CHECK-NEXT: ;       Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags for Module Functions
+
+; Call Graph of test source
+; main        -> [get_fptoui_flag, get_sitofp_fdiv_flag]
+; get_fptoui_flag  -> [get_sitofp_uitofp_flag, call_get_uitofp_flag]
+; get_sitofp_uitofp_flag  -> [call_get_fptoui_flag, call_get_sitofp_flag]
+; call_get_fptoui_flag  -> [get_fptoui_flag]
+; get_sitofp_fdiv_flag -> [get_no_flags, get_all_doubles_flags]
+; get_all_doubles_flags -> [call_get_sitofp_fdiv_flag]
+; call_get_sitofp_fdiv_flag  -> [get_sitofp_fdiv_flag]
+; call_get_sitofp_flag  -> [get_sitofp_flag]
+; call_get_uitofp_flag  -> [get_uitofp_flag]
+; get_sitofp_flag  -> []
+; get_uitofp_flag  -> []
+; get_no_flags  -> []
+;
+; Strongly connected components (with more than one function) of the CG
+;  [get_fptoui_flag, get_sitofp_uitofp_flag, call_get_fptoui_flag]
+;  [get_sitofp_fdiv_flag, get_all_doubles_flags, call_get_sitofp_fdiv_flag]
+
+;
+; CHECK: ; Function get_sitofp_flag : 0x00000044
+define double @get_sitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = sitofp i32 %0 to double
+  ret double %2
+}
+
+; CHECK: ; Function call_get_sitofp_flag : 0x00000044
+define double @call_get_sitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = tail call double @get_sitofp_flag(i32 noundef %0)
+  ret double %2
+}
+
+; CHECK: ; Function get_uitofp_flag : 0x00000044
+define double @get_uitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = uitofp i32 %0 to double
+  ret double %2
+}
+
+; CHECK: ; Function call_get_uitofp_flag : 0x00000044
+define double @call_get_uitofp_flag(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = tail call double @get_uitofp_flag(i32 noundef %0)
+  ret double %2
+}
+
+; CHECK: ; Function call_get_fptoui_flag : 0x00000044
+define double @call_get_fptoui_flag(double noundef %0) local_unnamed_addr #0 {
+  %2 = tail call double @get_fptoui_flag(double noundef %0)
+  ret double %2
+}
+
+; CHECK: ; Function get_fptoui_flag : 0x00000044
+define double @get_fptoui_flag(double noundef %0) local_unnamed_addr #0 {
+  %2 = fcmp ugt double %0, 5.000000e+00
+  br i1 %2, label %6, label %3
+
+3:                                                ; preds = %1
+  %4 = fptoui double %0 to i64
+  %5 = tail call double @get_sitofp_uitofp_flag(i64 noundef %4)
+  br label %9
+
+6:                                                ; preds = %1
+  %7 = fptoui double %0 to i32
+  %8 = tail call double @call_get_uitofp_flag(i32 noundef %7)
+  br label %9
+
+9:                                                ; preds = %6, %3
+  %10 = phi double [ %5, %3 ], [ %8, %6 ]
+  ret double %10
+}
+
+; CHECK: ; Function get_sitofp_uitofp_flag : 0x00000044
+define double @get_sitofp_uitofp_flag(i64 noundef %0) local_unnamed_addr #0 {
+  %2 = icmp ult i64 %0, 6
+  br i1 %2, label %3, label %7
+
+3:                                                ; preds = %1
+  %4 = add nuw nsw i64 %0, 1
+  %5 = uitofp i64 %4 to double
+  %6 = tail call double @call_get_fptoui_flag(double noundef %5)
+  br label %10
+
+7:                                                ; preds = %1
+  %8 = trunc i64 %0 to i32
+  %9 = tail call double @call_get_sitofp_flag(i32 noundef %8)
+  br label %10
+
+10:                                               ; preds = %7, %3
+  %11 = phi double [ %6, %3 ], [ %9, %7 ]
+  ret double %11
+}
+
+; CHECK: ; Function get_no_flags : 0x00000000
+define i32 @get_no_flags(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = mul nsw i32 %0, %0
+  ret i32 %2
+}
+
+; CHECK: ; Function call_get_sitofp_fdiv_flag : 0x00000044
+define i32 @call_get_sitofp_fdiv_flag(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = icmp eq i32 %0, 0
+  br i1 %2, label %5, label %3
+
+3:                                                ; preds = %1
+  %4 = mul nsw i32 %0, %0
+  br label %7
+
+5:                                                ; preds = %1
+  %6 = tail call double @get_sitofp_fdiv_flag(i32 noundef 0)
+  br label %7
+
+7:                                                ; preds = %5, %3
+  %8 = phi i32 [ %4, %3 ], [ 0, %5 ]
+  ret i32 %8
+}
+
+; CHECK: ; Function get_sitofp_fdiv_flag : 0x00000044
+define double @get_sitofp_fdiv_flag(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = icmp sgt i32 %0, 5
+  br i1 %2, label %3, label %6
+
+3:                                                ; preds = %1
+  %4 = tail call i32 @get_no_flags(i32 noundef %0)
+  %5 = sitofp i32 %4 to double
+  br label %9
+
+6:                                                ; preds = %1
+  %7 = tail call double @get_all_doubles_flags(i32 noundef %0)
+  %8 = fdiv double %7, 3.000000e+00
+  br label %9
+
+9:                                                ; preds = %6, %3
+  %10 = phi double [ %5, %3 ], [ %8, %6 ]
+  ret double %10
+}
+
+; CHECK: ; Function get_all_doubles_flags : 0x00000044
+define double @get_all_doubles_flags(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = tail call i32 @call_get_sitofp_fdiv_flag(i32 noundef %0)
+  %3 = icmp eq i32 %2, 0
+  %4 = select i1 %3, double 1.000000e+01, double 1.000000e+02
+  ret double %4
+}
+
+; CHECK: ; Function main : 0x00000044
+define i32 @main() local_unnamed_addr #0 {
+  %1 = tail call double @get_fptoui_flag(double noundef 1.000000e+00)
+  %2 = tail call double @get_sitofp_fdiv_flag(i32 noundef 4)
+  %3 = fadd double %1, %2
+  %4 = fcmp ogt double %3, 0.000000e+00
+  %5 = zext i1 %4 to i32
+  ret i32 %5
+}
+
+attributes #0 = { convergent nounwind "hlsl.export" }

>From e53cd26d830fe57c9760ad18144760265d82c72a Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Tue, 14 Jan 2025 11:10:00 -0500
Subject: [PATCH 2/3] Delete unnecessary #include

---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index e956189f8ecd4e..4bcc01a90b1706 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -14,7 +14,6 @@
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
 #include "llvm/ADT/SCCIterator.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/CallGraph.h"
 #include "llvm/Analysis/DXILResource.h"

>From c1e134321ada333d908fb1f9c20b201d5b8798d6 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Tue, 14 Jan 2025 12:10:47 -0500
Subject: [PATCH 3/3] Delete braces around single-statement if body

---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 4bcc01a90b1706..b1ff975d4dae96 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -76,9 +76,9 @@ void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF,
   if (auto *CI = dyn_cast<CallInst>(&I)) {
     const Function *CF = CI->getCalledFunction();
     // Merge-in shader flags mask of the called function in the current module
-    if (FunctionFlags.contains(CF)) {
+    if (FunctionFlags.contains(CF))
       CSF.merge(FunctionFlags[CF]);
-    }
+
     // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
     // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
   }



More information about the llvm-commits mailing list