[llvm] [DirectX] Propagate shader flags mask of callees to callers (PR #118306)

Chris B via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 14 08:44:54 PST 2025


================
@@ -62,27 +73,65 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I,
     }
     }
   }
+  // Handle call instructions
+  if (auto *CI = dyn_cast<CallInst>(&I)) {
+    const Function *CF = CI->getCalledFunction();
+    // Merge-in shader flags mask of the called function in the current module
+    if (FunctionFlags.contains(CF)) {
+      CSF.merge(FunctionFlags[CF]);
+    }
+    // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
+    // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
+  }
 }
 
-void ModuleShaderFlags::initialize(const Module &M, DXILResourceTypeMap &DRTM) {
-
-  // Collect shader flags for each of the functions
-  for (const auto &F : M.getFunctionList()) {
-    if (F.isDeclaration()) {
-      assert(!F.getName().starts_with("dx.op.") &&
-             "DXIL Shader Flag analysis should not be run post-lowering.");
-      continue;
+/// Construct ModuleShaderFlags for module Module M
+void ModuleShaderFlags::initialize(Module &M, DXILResourceTypeMap &DRTM) {
+  CallGraph CG(M);
+
+  // Compute Shader Flags Mask for all functions using post-order visit of SCC
+  // of the call graph.
+  for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
+       ++SCCI) {
+    const std::vector<CallGraphNode *> &CurSCC = *SCCI;
+
+    // Union of shader masks of all functions in CurSCC
+    ComputedShaderFlags SCCSF;
+    // List of functions in CurSCC that are neither external nor declarations
+    // and hence whose flags are collected
+    SmallVector<Function *> CurSCCFuncs;
+    for (CallGraphNode *CGN : CurSCC) {
+      Function *F = CGN->getFunction();
+      if (!F)
+        continue;
+
+      if (F->isDeclaration()) {
+        assert(!F->getName().starts_with("dx.op.") &&
----------------
llvm-beanz wrote:

Should this be `report_fatal_error` instead of an assert?

https://github.com/llvm/llvm-project/pull/118306


More information about the llvm-commits mailing list