[llvm] [DirectX] Propagate shader flags mask of callees to callers (PR #118306)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 2 07:02:22 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: S. Bharadwaj Yadavalli (bharadwajy)

<details>
<summary>Changes</summary>

Propagate the shader flags mask of each callee to its callers.

Add a test to verify the propagation of shader flags.

---
Full diff: https://github.com/llvm/llvm-project/pull/118306.diff


4 Files Affected:

- (modified) llvm/lib/Target/DirectX/DXILShaderFlags.cpp (+39-4) 
- (modified) llvm/lib/Target/DirectX/DXILShaderFlags.h (+2) 
- (modified) llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll (+7) 
- (added) llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll (+92) 


``````````diff
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index d6917dce98abd5..f242204363cfe8 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -15,6 +15,7 @@
 #include "DirectX.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
@@ -47,10 +48,14 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF,
 }
 
 void ModuleShaderFlags::initialize(const Module &M) {
+  SmallVector<const Function *> WorkList;
   // Collect shader flags for each of the functions
   for (const auto &F : M.getFunctionList()) {
     if (F.isDeclaration())
       continue;
+    if (!F.user_empty()) {
+      WorkList.push_back(&F);
+    }
     ComputedShaderFlags CSF;
     for (const auto &BB : F)
       for (const auto &I : BB)
@@ -61,6 +66,21 @@ void ModuleShaderFlags::initialize(const Module &M) {
     CombinedSFMask.merge(CSF);
   }
   llvm::sort(FunctionFlags);
+  // Propagate shader flag mask of functions to their callers.
+  while (!WorkList.empty()) {
+    const Function *Func = WorkList.pop_back_val();
+    if (!Func->user_empty()) {
+      ComputedShaderFlags FuncSF = getFunctionFlags(Func);
+      // Update mask of callers with that of Func
+      for (const auto User : Func->users()) {
+        if (const CallInst *CI = dyn_cast<CallInst>(User)) {
+          const Function *Caller = CI->getParent()->getParent();
+          if (mergeFunctionShaderFlags(Caller, FuncSF))
+            WorkList.push_back(Caller);
+        }
+      }
+    }
+  }
 }
 
 void ComputedShaderFlags::print(raw_ostream &OS) const {
@@ -81,16 +101,31 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
   OS << ";\n";
 }
 
-/// Return the shader flags mask of the specified function Func.
-const ComputedShaderFlags &
-ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
+auto ModuleShaderFlags::getFunctionShaderFlagInfo(const Function *Func) const {
   const auto Iter = llvm::lower_bound(
       FunctionFlags, Func,
       [](const std::pair<const Function *, ComputedShaderFlags> FSM,
          const Function *FindFunc) { return (FSM.first < FindFunc); });
   assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
          "No Shader Flags Mask exists for function");
-  return Iter->second;
+  return Iter;
+}
+
+/// Merge mask NewSF to that of Func, if different.
+/// Return true if mask of Func is changed, else false.
+bool ModuleShaderFlags::mergeFunctionShaderFlags(
+    const Function *Func, const ComputedShaderFlags NewSF) {
+  const auto FuncSFInfo = getFunctionShaderFlagInfo(Func);
+  if ((FuncSFInfo->second & NewSF) != NewSF) {
+    const_cast<ComputedShaderFlags &>(FuncSFInfo->second).merge(NewSF);
+    return true;
+  }
+  return false;
+}
+/// Return the shader flags mask of the specified function Func.
+const ComputedShaderFlags &
+ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
+  return getFunctionShaderFlagInfo(Func)->second;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 2d60137f8b191c..8c581f243ca98b 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -95,6 +95,8 @@ struct ModuleShaderFlags {
   SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
   /// Combined Shader Flag Mask of all functions of the module
   ComputedShaderFlags CombinedSFMask{};
+  auto getFunctionShaderFlagInfo(const Function *) const;
+  bool mergeFunctionShaderFlags(const Function *, ComputedShaderFlags);
 };
 
 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
index 6332ef806a0d8f..8e5e61b42469ad 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
@@ -12,6 +12,13 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Shader Flags for Module Functions
 
+;CHECK: ; Function top_level : 0x00000044
+define void @top_level() #0 {
+  call void @test_uitofp_i64(i64 noundef 5)
+  ret void
+}
+
+
 ; CHECK: ; Function test_fdiv_double : 0x00000044
 define double @test_fdiv_double(double %a, double %b) #0 {
   %res = fdiv double %a, %b
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
new file mode 100644
index 00000000000000..93d634c0384ae7
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/propagate-function-flags-test.ll
@@ -0,0 +1,92 @@
+; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; CHECK: ; Combined Shader Flags for Module
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ;       Double-precision floating point
+; CHECK-NEXT: ;       Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags for Module Functions
+
+; CHECK: ; Function call_n6 : 0x00000044
+define double @call_n6(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = sitofp i32 %0 to double
+  ret double %2
+}
+; CHECK: ; Function call_n4 : 0x00000044
+define double @call_n4(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = tail call double @call_n6(i32 noundef %0)
+  ret double %2
+}
+
+; CHECK: ; Function call_n7 : 0x00000044
+define double @call_n7(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = uitofp i32 %0 to double
+  ret double %2
+}
+
+; CHECK: ; Function call_n5 : 0x00000044
+define double @call_n5(i32 noundef %0) local_unnamed_addr #0 {
+  %2 = tail call double @call_n7(i32 noundef %0)
+  ret double %2
+}
+
+; CHECK: ; Function call_n2 : 0x00000044
+define double @call_n2(i64 noundef %0) local_unnamed_addr #0 {
+  %2 = icmp ult i64 %0, 6
+  br i1 %2, label %3, label %7
+
+3:                                                ; preds = %1
+  %4 = add nuw nsw i64 %0, 1
+  %5 = uitofp i64 %4 to double
+  %6 = tail call double @call_n1(double noundef %5)
+  br label %10
+
+7:                                                ; preds = %1
+  %8 = trunc i64 %0 to i32
+  %9 = tail call double @call_n4(i32 noundef %8)
+  br label %10
+
+10:                                               ; preds = %7, %3
+  %11 = phi double [ %6, %3 ], [ %9, %7 ]
+  ret double %11
+}
+
+; CHECK: ; Function call_n1 : 0x00000044
+define double @call_n1(double noundef %0) local_unnamed_addr #0 {
+  %2 = fcmp ugt double %0, 5.000000e+00
+  br i1 %2, label %6, label %3
+
+3:                                                ; preds = %1
+  %4 = fptoui double %0 to i64
+  %5 = tail call double @call_n2(i64 noundef %4)
+  br label %9
+
+6:                                                ; preds = %1
+  %7 = fptoui double %0 to i32
+  %8 = tail call double @call_n5(i32 noundef %7)
+  br label %9
+
+9:                                                ; preds = %6, %3
+  %10 = phi double [ %5, %3 ], [ %8, %6 ]
+  ret double %10
+}
+
+; CHECK: ; Function call_n3 : 0x00000044
+define double @call_n3(double noundef %0) local_unnamed_addr #0 {
+  %2 = fdiv double %0, 3.000000e+00
+  ret double %2
+}
+
+; CHECK: ; Function main : 0x00000044
+define i32 @main() local_unnamed_addr #0 {
+  %1 = tail call double @call_n1(double noundef 1.000000e+00)
+  %2 = tail call double @call_n3(double noundef %1)
+  ret i32 0
+}
+
+attributes #0 = { convergent norecurse nounwind "hlsl.export"}

``````````

</details>


https://github.com/llvm/llvm-project/pull/118306


More information about the llvm-commits mailing list