[llvm] [DirectX] Propagate shader flags mask of callees to callers (PR #118306)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 5 09:15:09 PST 2024


https://github.com/bharadwajy updated https://github.com/llvm/llvm-project/pull/118306

>From 6a4fef8b1c8ad78db63016ddc927bbbebeea04c5 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 27 Nov 2024 11:25:34 -0500
Subject: [PATCH 1/4] Propagate shader flags mask of callees to callers

Add test to verify propagation of shader flags
---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp   | 43 ++++++++-
 llvm/lib/Target/DirectX/DXILShaderFlags.h     |  2 +
 .../DirectX/ShaderFlags/double-extensions.ll  |  7 ++
 .../propagate-function-flags-test.ll          | 92 +++++++++++++++++++
 4 files changed, 140 insertions(+), 4 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index d6917dce98abd5..f242204363cfe8 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -15,6 +15,7 @@
 #include "DirectX.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
@@ -47,10 +48,14 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF,
 }
 
 void ModuleShaderFlags::initialize(const Module &M) {
+  SmallVector<const Function *> WorkList;
   // Collect shader flags for each of the functions
   for (const auto &F : M.getFunctionList()) {
     if (F.isDeclaration())
       continue;
+    if (!F.user_empty()) {
+      WorkList.push_back(&F);
+    }
     ComputedShaderFlags CSF;
     for (const auto &BB : F)
       for (const auto &I : BB)
@@ -61,6 +66,21 @@ void ModuleShaderFlags::initialize(const Module &M) {
     CombinedSFMask.merge(CSF);
   }
   llvm::sort(FunctionFlags);
+  // Propagate shader flag mask of functions to their callers.
+  while (!WorkList.empty()) {
+    const Function *Func = WorkList.pop_back_val();
+    if (!Func->user_empty()) {
+      ComputedShaderFlags FuncSF = getFunctionFlags(Func);
+      // Update mask of callers with that of Func
+      for (const auto User : Func->users()) {
+        if (const CallInst *CI = dyn_cast<CallInst>(User)) {
+          const Function *Caller = CI->getParent()->getParent();
+          if (mergeFunctionShaderFlags(Caller, FuncSF))
+            WorkList.push_back(Caller);
+        }
+      }
+    }
+  }
 }
 
 void ComputedShaderFlags::print(raw_ostream &OS) const {
@@ -81,16 +101,31 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
   OS << ";\n";
 }
 
-/// Return the shader flags mask of the specified function Func.
-const ComputedShaderFlags &
-ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
+auto ModuleShaderFlags::getFunctionShaderFlagInfo(const Function *Func) const {
   const auto Iter = llvm::lower_bound(
       FunctionFlags, Func,
       [](const std::pair<const Function *, ComputedShaderFlags> FSM,
          const Function *FindFunc) { return (FSM.first < FindFunc); });
   assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
          "No Shader Flags Mask exists for function");
-  return Iter->second;
+  return Iter;
+}
+
+/// Merge mask NewSF to that of Func, if different.
+/// Return true if mask of Func is changed, else false.
+bool ModuleShaderFlags::mergeFunctionShaderFlags(
+    const Function *Func, const ComputedShaderFlags NewSF) {
+  const auto FuncSFInfo = getFunctionShaderFlagInfo(Func);
+  if ((FuncSFInfo->second & NewSF) != NewSF) {
+    const_cast<ComputedShaderFlags &>(FuncSFInfo->second).merge(NewSF);
+    return true;
+  }
+  return false;
+}
+/// Return the shader flags mask of the specified function Func.
+const ComputedShaderFlags &
+ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
+  return getFunctionShaderFlagInfo(Func)->second;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 2d60137f8b191c..8c581f243ca98b 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -95,6 +95,8 @@ struct ModuleShaderFlags {
   SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
   /// Combined Shader Flag Mask of all functions of the module
   ComputedShaderFlags CombinedSFMask{};
+  auto getFunctionShaderFlagInfo(const Function *) const;
+  bool mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
 };
 
 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
index 6332ef806a0d8f..8e5e61b42469ad 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
@@ -12,6 +12,13 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Shader Flags for Module Functions
 
+;CHECK: ; Function top_level : 0x00000044
+define void @top_level() #0 {
+  call void @test_uitofp_i64(i64 noundef 5)
+  ret void
+}
+
+
 ; CHECK: ; Function test_fdiv_double : 0x00000044
 define double @test_fdiv_double(double %a, double %b) #0 {
   %res = fdiv double %a, %b
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
new file mode 100644
index 00000000000000..93d634c0384ae7
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
@@ -0,0 +1,92 @@
+; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; CHECK: ; Combined Shader Flags for Module
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ;       Double-precision floating point
+; CHECK-NEXT: ;       Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags for Module Functions
+
+; CHECK: ; Function call_n6 : 0x00000044
+define double @call_n6(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = sitofp i32 %0 to double
+  ret double %2
+}
+; CHECK: ; Function call_n4 : 0x00000044
+define double @call_n4(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = tail call double @call_n6(i32 noundef %0)
+  ret double %2
+}
+
+; CHECK: ; Function call_n7 : 0x00000044
+define double @call_n7(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = uitofp i32 %0 to double
+  ret double %2
+}
+
+; CHECK: ; Function call_n5 : 0x00000044
+define double @call_n5(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = tail call double @call_n7(i32 noundef %0)
+  ret double %2
+}
+
+; CHECK: ; Function call_n2 : 0x00000044
+define double @call_n2(i64 noundef %0) local_unnamed_addr #0 {
+  %2 = icmp ult i64 %0, 6
+  br i1 %2, label %3, label %7
+
+3:                                                ; preds = %1
+  %4 = add nuw nsw i64 %0, 1
+  %5 = uitofp i64 %4 to double
+  %6 = tail call double @call_n1(double noundef %5)
+  br label %10
+
+7:                                                ; preds = %1
+  %8 = trunc i64 %0 to i32
+  %9 = tail call double @call_n4(i32 noundef %8)
+  br label %10
+
+10:                                               ; preds = %7, %3
+  %11 = phi double [ %6, %3 ], [ %9, %7 ]
+  ret double %11
+}
+
+; CHECK: ; Function call_n1 : 0x00000044
+define double @call_n1(double noundef %0) local_unnamed_addr #0 {
+  %2 = fcmp ugt double %0, 5.000000e+00
+  br i1 %2, label %6, label %3
+
+3:                                                ; preds = %1
+  %4 = fptoui double %0 to i64
+  %5 = tail call double @call_n2(i64 noundef %4)
+  br label %9
+
+6:                                                ; preds = %1
+  %7 = fptoui double %0 to i32
+  %8 = tail call double @call_n5(i32 noundef %7)
+  br label %9
+
+9:                                                ; preds = %6, %3
+  %10 = phi double [ %5, %3 ], [ %8, %6 ]
+  ret double %10
+}
+
+; CHECK: ; Function call_n3 : 0x00000044
+define double @call_n3(double noundef %0) local_unnamed_addr #0 {
+  %2 = fdiv double %0, 3.000000e+00
+  ret double %2
+}
+
+; CHECK: ; Function main : 0x00000044
+define i32 @main() local_unnamed_addr #0 {
+  %1 = tail call double @call_n1(double noundef 1.000000e+00)
+  %2 = tail call double @call_n3(double noundef %1)
+  ret i32 0
+}
+
+attributes #0 = { convergent norecurse nounwind "hlsl.export"}

>From a360efc5d0a90b4bf36384fa779df01c07bba1fd Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Tue, 3 Dec 2024 19:52:07 -0500
Subject: [PATCH 2/4] Address PR feedback

---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 48 +++++++++++----------
 llvm/lib/Target/DirectX/DXILShaderFlags.h   | 10 ++---
 2 files changed, 30 insertions(+), 28 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index f242204363cfe8..44fa5a2bb060b6 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -70,7 +70,7 @@ void ModuleShaderFlags::initialize(const Module &M) {
   while (!WorkList.empty()) {
     const Function *Func = WorkList.pop_back_val();
     if (!Func->user_empty()) {
-      ComputedShaderFlags FuncSF = getFunctionFlags(Func);
+      const ComputedShaderFlags &FuncSF = getFunctionFlags(Func);
       // Update mask of callers with that of Func
       for (const auto User : Func->users()) {
         if (const CallInst *CI = dyn_cast<CallInst>(User)) {
@@ -101,31 +101,35 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
   OS << ";\n";
 }
 
-auto ModuleShaderFlags::getFunctionShaderFlagInfo(const Function *Func) const {
-  const auto Iter = llvm::lower_bound(
-      FunctionFlags, Func,
-      [](const std::pair<const Function *, ComputedShaderFlags> FSM,
-         const Function *FindFunc) { return (FSM.first < FindFunc); });
-  assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
-         "No Shader Flags Mask exists for function");
-  return Iter;
+static bool compareShaderFlagsInfo(
+    const std::pair<const Function *, ComputedShaderFlags> FSM,
+    const Function *FindFunc) {
+  return (FSM.first < FindFunc);
 }
 
-/// Merge mask NewSF to that of Func, if different.
-/// Return true if mask of Func is changed, else false.
-bool ModuleShaderFlags::mergeFunctionShaderFlags(
-    const Function *Func, const ComputedShaderFlags NewSF) {
-  const auto FuncSFInfo = getFunctionShaderFlagInfo(Func);
-  if ((FuncSFInfo->second & NewSF) != NewSF) {
-    const_cast<ComputedShaderFlags &>(FuncSFInfo->second).merge(NewSF);
-    return true;
-  }
-  return false;
-}
 /// Return the shader flags mask of the specified function Func.
 const ComputedShaderFlags &
 ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
-  return getFunctionShaderFlagInfo(Func)->second;
+  const std::pair<const Function *, ComputedShaderFlags> *Iter =
+      llvm::lower_bound(FunctionFlags, Func, compareShaderFlagsInfo);
+  assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
+         "Get Shader Flags : No Shader Flags Mask exists for function");
+  return Iter->second;
+}
+
+/// Merge specified shader flags mask SF with current mask of the specified
+/// function Func.
+/// Return true if merge operation changes the value of shader flags mask of
+/// Func; else false.
+bool ModuleShaderFlags::mergeFunctionShaderFlags(const Function *Func,
+                                                 ComputedShaderFlags SF) {
+  std::pair<const Function *, ComputedShaderFlags> *Iter =
+      llvm::lower_bound(FunctionFlags, Func, compareShaderFlagsInfo);
+  assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
+         "Merge Shader Flags : No Shader Flags Mask exists for function");
+  ComputedShaderFlags PreMergeSF = Iter->second;
+  Iter->second.merge(SF);
+  return (PreMergeSF != Iter->second);
 }
 
 //===----------------------------------------------------------------------===//
@@ -152,7 +156,7 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
   for (const auto &F : M.getFunctionList()) {
     if (F.isDeclaration())
       continue;
-    auto SFMask = FlagsInfo.getFunctionFlags(&F);
+    const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F);
     OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
                   (uint64_t)(SFMask));
   }
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 8c581f243ca98b..158e4ea9033364 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -70,11 +70,10 @@ struct ComputedShaderFlags {
     return FeatureFlags;
   }
 
-  void merge(const uint64_t IVal) {
+  void merge(const ComputedShaderFlags CSF) {
 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str)          \
-  FlagName |= (IVal & getMask(DxilModuleBit));
-#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str)                         \
-  FlagName |= (IVal & getMask(DxilModuleBit));
+  FlagName |= CSF.FlagName;
+#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName |= CSF.FlagName;
 #include "llvm/BinaryFormat/DXContainerConstants.def"
     return;
   }
@@ -92,10 +91,9 @@ struct ModuleShaderFlags {
   /// Vector of sorted Function-Shader Flag mask pairs representing properties
   /// of each of the functions in the module. Shader Flags of each function
   /// represent both module-level and function-level flags
-  SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
+  SmallVector<std::pair<const Function *, ComputedShaderFlags>> FunctionFlags;
   /// Combined Shader Flag Mask of all functions of the module
   ComputedShaderFlags CombinedSFMask{};
-  auto getFunctionShaderFlagInfo(const Function *) const;
   bool mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
 };
 

>From b7cb3cd12a5816f55037e55cb981e10cdc462fbf Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 4 Dec 2024 15:07:11 -0500
Subject: [PATCH 3/4] Address PR feedback - Use DenseMap instead of SmallVector
 for Function-Shader Flags Mask map - Return true from
 ComputedShaderFlags::merge() if the value is modified - Style nits

---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 48 +++++++++------------
 llvm/lib/Target/DirectX/DXILShaderFlags.h   | 20 ++++++---
 2 files changed, 33 insertions(+), 35 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 44fa5a2bb060b6..3e1009124ceeca 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -53,31 +53,33 @@ void ModuleShaderFlags::initialize(const Module &M) {
   for (const auto &F : M.getFunctionList()) {
     if (F.isDeclaration())
       continue;
-    if (!F.user_empty()) {
+
+    if (!F.user_empty())
       WorkList.push_back(&F);
-    }
+
     ComputedShaderFlags CSF;
     for (const auto &BB : F)
       for (const auto &I : BB)
         updateFunctionFlags(CSF, I);
     // Insert shader flag mask for function F
-    FunctionFlags.push_back({&F, CSF});
+    FunctionFlags.insert({&F, CSF});
     // Update combined shader flags mask
     CombinedSFMask.merge(CSF);
   }
-  llvm::sort(FunctionFlags);
+
   // Propagate shader flag mask of functions to their callers.
   while (!WorkList.empty()) {
     const Function *Func = WorkList.pop_back_val();
-    if (!Func->user_empty()) {
-      const ComputedShaderFlags &FuncSF = getFunctionFlags(Func);
-      // Update mask of callers with that of Func
-      for (const auto User : Func->users()) {
-        if (const CallInst *CI = dyn_cast<CallInst>(User)) {
-          const Function *Caller = CI->getParent()->getParent();
-          if (mergeFunctionShaderFlags(Caller, FuncSF))
-            WorkList.push_back(Caller);
-        }
+    if (Func->user_empty())
+      continue;
+
+    const ComputedShaderFlags &FuncSF = getFunctionFlags(Func);
+    // Update mask of callers with that of Func
+    for (const auto User : Func->users()) {
+      if (const CallInst *CI = dyn_cast<CallInst>(User)) {
+        const Function *Caller = CI->getParent()->getParent();
+        if (mergeFunctionShaderFlags(Caller, FuncSF))
+          WorkList.push_back(Caller);
       }
     }
   }
@@ -101,17 +103,10 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
   OS << ";\n";
 }
 
-static bool compareShaderFlagsInfo(
-    const std::pair<const Function *, ComputedShaderFlags> FSM,
-    const Function *FindFunc) {
-  return (FSM.first < FindFunc);
-}
-
 /// Return the shader flags mask of the specified function Func.
 const ComputedShaderFlags &
 ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
-  const std::pair<const Function *, ComputedShaderFlags> *Iter =
-      llvm::lower_bound(FunctionFlags, Func, compareShaderFlagsInfo);
+  auto Iter = FunctionFlags.find(Func);
   assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
          "Get Shader Flags : No Shader Flags Mask exists for function");
   return Iter->second;
@@ -119,17 +114,14 @@ ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
 
 /// Merge specified shader flags mask SF with current mask of the specified
 /// function Func.
-/// Return true if merge operation changes the value of shader flags mask of
-/// Func; else false.
+/// Return merge result status viz., true if merge operation changes the value
+/// of shader flags mask of Func; else false.
 bool ModuleShaderFlags::mergeFunctionShaderFlags(const Function *Func,
                                                  ComputedShaderFlags SF) {
-  std::pair<const Function *, ComputedShaderFlags> *Iter =
-      llvm::lower_bound(FunctionFlags, Func, compareShaderFlagsInfo);
+  auto Iter = FunctionFlags.find(Func);
   assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
          "Merge Shader Flags : No Shader Flags Mask exists for function");
-  ComputedShaderFlags PreMergeSF = Iter->second;
-  Iter->second.merge(SF);
-  return (PreMergeSF != Iter->second);
+  return Iter->second.merge(SF);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 158e4ea9033364..e3a2a26367c7da 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -70,12 +70,18 @@ struct ComputedShaderFlags {
     return FeatureFlags;
   }
 
-  void merge(const ComputedShaderFlags CSF) {
+  /// Return merge result status viz., true if merge operation changes
+  /// shader flags mask; else false.
+  bool merge(const ComputedShaderFlags CSF) {
+    bool Changed = false;
 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str)          \
+  Changed |= (FlagName ^ CSF.FlagName);                                        \
+  FlagName |= CSF.FlagName;
+#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str)                         \
+  Changed |= (FlagName ^ CSF.FlagName);                                        \
   FlagName |= CSF.FlagName;
-#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName |= CSF.FlagName;
 #include "llvm/BinaryFormat/DXContainerConstants.def"
-    return;
+    return Changed;
   }
 
   void print(raw_ostream &OS = dbgs()) const;
@@ -88,10 +94,10 @@ struct ModuleShaderFlags {
   const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; }
 
 private:
-  /// Vector of sorted Function-Shader Flag mask pairs representing properties
-  /// of each of the functions in the module. Shader Flags of each function
-  /// represent both module-level and function-level flags
-  SmallVector<std::pair<const Function *, ComputedShaderFlags>> FunctionFlags;
+  /// Map of Function-Shader Flag Mask pairs representing properties of each of
+  /// the functions in the module. Shader Flags of each function represent both
+  /// module-level and function-level flags
+  DenseMap<const Function *, ComputedShaderFlags> FunctionFlags;
   /// Combined Shader Flag Mask of all functions of the module
   ComputedShaderFlags CombinedSFMask{};
   bool mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);

>From 0e8e937d40ef76bc3b760581b1c0447f73f648fc Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Thu, 5 Dec 2024 11:55:11 -0500
Subject: [PATCH 4/4] Address PR feedback: Use CallGraph and scc_iterator to
 discover functions for shader flag computation and propagation.

ComputedShaderFlags::merge() no longer tracks or returns a changed
status, since that functionality is no longer needed.

Update test.
---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp   | 103 ++++++++++++------
 llvm/lib/Target/DirectX/DXILShaderFlags.h     |  15 +--
 .../propagate-function-flags-test.ll          |   8 +-
 3 files changed, 81 insertions(+), 45 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 3e1009124ceeca..19cbb35bc76f31 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -13,7 +13,9 @@
 
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
+#include "llvm/ADT/SCCIterator.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/CallGraph.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
@@ -47,39 +49,76 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF,
   }
 }
 
-void ModuleShaderFlags::initialize(const Module &M) {
-  SmallVector<const Function *> WorkList;
-  // Collect shader flags for each of the functions
-  for (const auto &F : M.getFunctionList()) {
-    if (F.isDeclaration())
+void ModuleShaderFlags::initialize(Module &M) {
+  CallGraph CG(M);
+
+  // Compute Shader Flags Mask for all functions using post-order visit of SCC
+  // of the call graph.
+  for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
+       ++SCCI) {
+    const std::vector<CallGraphNode *> &CurSCC = *SCCI;
+
+    // Union of shader masks of all functions in CurSCC
+    ComputedShaderFlags SCCSF;
+    for (CallGraphNode *CGN : CurSCC) {
+      Function *F = CGN->getFunction();
+      if (!F)
+        continue;
+
+      if (F->isDeclaration())
+        continue;
+
+      ComputedShaderFlags CSF;
+      for (const auto &BB : *F)
+        for (const auto &I : BB)
+          updateFunctionFlags(CSF, I);
+      // Insert shader flag mask for function F
+      FunctionFlags.insert({F, CSF});
+      // Update combined shader flags mask for all functions of the module
+      CombinedSFMask.merge(CSF);
+      // Update combined shader flags mask for all functions in this SCC
+      SCCSF.merge(CSF);
+    }
+
+    if (CurSCC.size() < 2)
       continue;
 
-    if (!F.user_empty())
-      WorkList.push_back(&F);
-
-    ComputedShaderFlags CSF;
-    for (const auto &BB : F)
-      for (const auto &I : BB)
-        updateFunctionFlags(CSF, I);
-    // Insert shader flag mask for function F
-    FunctionFlags.insert({&F, CSF});
-    // Update combined shader flags mask
-    CombinedSFMask.merge(CSF);
+    // Every function in a multi-function SCC gets the union (SCCSF) of the
+    // masks of all functions in that SCC. NOTE(review): unlike the loop above,
+    // declarations are not skipped; mergeFunctionShaderFlags asserts on them.
+    for (CallGraphNode *CGN : CurSCC) {
+      Function *F = CGN->getFunction();
+      if (!F)
+        continue;
+      mergeFunctionShaderFlags(F, SCCSF);
+    }
   }
 
-  // Propagate shader flag mask of functions to their callers.
-  while (!WorkList.empty()) {
-    const Function *Func = WorkList.pop_back_val();
-    if (Func->user_empty())
-      continue;
-
-    const ComputedShaderFlags &FuncSF = getFunctionFlags(Func);
-    // Update mask of callers with that of Func
-    for (const auto User : Func->users()) {
-      if (const CallInst *CI = dyn_cast<CallInst>(User)) {
-        const Function *Caller = CI->getParent()->getParent();
-        if (mergeFunctionShaderFlags(Caller, FuncSF))
-          WorkList.push_back(Caller);
+  // Propagate shader flag masks of callees to their callers with a second
+  // post-order walk of the call graph, visiting callee SCCs before callers.
+  for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
+       ++SCCI) {
+    const std::vector<CallGraphNode *> &CurSCC = *SCCI;
+    for (CallGraphNode *CGN : CurSCC) {
+      Function *F = CGN->getFunction();
+      if (!F)
+        continue;
+
+      if (F->isDeclaration() || F->user_empty())
+        continue;
+
+      const ComputedShaderFlags &FuncSF = getFunctionFlags(F);
+      // Merge F's mask into each caller's mask, including callers in CurSCC.
+      for (const auto User : F->users()) {
+        if (const CallInst *CI = dyn_cast<CallInst>(User)) {
+          const Function *Caller = CI->getParent()->getParent();
+          // NOTE(review): SCC members got equal masks in the first walk, but
+          // earlier (callee) SCCs merge only into their direct callers, so
+          // members may differ here. Bits merged into a member after it has
+          // been visited in this loop are not forwarded to its callers, so the
+          // result depends on the order of CurSCC - TODO confirm and fix.
+          mergeFunctionShaderFlags(Caller, FuncSF);
+        }
       }
     }
   }
@@ -114,14 +153,12 @@ ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
 
 /// Merge specified shader flags mask SF with current mask of the specified
 /// function Func.
-/// Return merge result status viz., true if merge operation changes the value
-/// of shader flags mask of Func; else false.
-bool ModuleShaderFlags::mergeFunctionShaderFlags(const Function *Func,
+void ModuleShaderFlags::mergeFunctionShaderFlags(const Function *Func,
                                                  ComputedShaderFlags SF) {
   auto Iter = FunctionFlags.find(Func);
   assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
          "Merge Shader Flags : No Shader Flags Mask exists for function");
-  return Iter->second.merge(SF);
+  Iter->second.merge(SF);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index e3a2a26367c7da..b7837679da016b 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -70,18 +70,11 @@ struct ComputedShaderFlags {
     return FeatureFlags;
   }
 
-  /// Return merge result status viz., true if merge operation changes
-  /// shader flags mask; else false.
-  bool merge(const ComputedShaderFlags CSF) {
-    bool Changed = false;
+  void merge(const ComputedShaderFlags CSF) {
 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str)          \
-  Changed |= (FlagName ^ CSF.FlagName);                                        \
-  FlagName |= CSF.FlagName;
-#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str)                         \
-  Changed |= (FlagName ^ CSF.FlagName);                                        \
   FlagName |= CSF.FlagName;
+#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) FlagName |= CSF.FlagName;
 #include "llvm/BinaryFormat/DXContainerConstants.def"
-    return Changed;
   }
 
   void print(raw_ostream &OS = dbgs()) const;
@@ -89,7 +82,7 @@ struct ComputedShaderFlags {
 };
 
 struct ModuleShaderFlags {
-  void initialize(const Module &);
+  void initialize(Module &);
   const ComputedShaderFlags &getFunctionFlags(const Function *) const;
   const ComputedShaderFlags &getCombinedFlags() const { return CombinedSFMask; }
 
@@ -100,7 +93,7 @@ struct ModuleShaderFlags {
   DenseMap<const Function *, ComputedShaderFlags> FunctionFlags;
   /// Combined Shader Flag Mask of all functions of the module
   ComputedShaderFlags CombinedSFMask{};
-  bool mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
+  void mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
 };
 
 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
index 93d634c0384ae7..914ca021e95cd9 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
@@ -43,7 +43,7 @@ define double @call_n2(i64 noundef %0) local_unnamed_addr #0 {
 3:                                                ; preds = %1
   %4 = add nuw nsw i64 %0, 1
   %5 = uitofp i64 %4 to double
-  %6 = tail call double @call_n1(double noundef %5)
+  %6 = tail call double @call_n11(double noundef %5)
   br label %10
 
 7:                                                ; preds = %1
@@ -56,6 +56,12 @@ define double @call_n2(i64 noundef %0) local_unnamed_addr #0 {
   ret double %11
 }
 
+; CHECK: ; Function call_n11 : 0x00000044
+define double @call_n11(double noundef %0) local_unnamed_addr #1 {
+  %2 = tail call double @call_n1(double noundef %0)
+  ret double %2
+}
+
 ; CHECK: ; Function call_n1 : 0x00000044
 define double @call_n1(double noundef %0) local_unnamed_addr #0 {
   %2 = fcmp ugt double %0, 5.000000e+00



More information about the llvm-commits mailing list