[llvm] [DirectX] Propagate shader flags mask of callees to callers (PR #118306)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 6 12:52:34 PST 2024


================
@@ -46,21 +49,79 @@ static void updateFunctionFlags(ComputedShaderFlags &CSF,
   }
 }
 
-void ModuleShaderFlags::initialize(const Module &M) {
-  // Collect shader flags for each of the functions
-  for (const auto &F : M.getFunctionList()) {
-    if (F.isDeclaration())
+void ModuleShaderFlags::initialize(Module &M) {
+  CallGraph CG(M);
+
+  // Compute Shader Flags Mask for all functions using post-order visit of SCC
+  // of the call graph.
+  for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
+       ++SCCI) {
+    const std::vector<CallGraphNode *> &CurSCC = *SCCI;
+
+    // Union of shader masks of all functions in CurSCC
+    ComputedShaderFlags SCCSF;
+    for (CallGraphNode *CGN : CurSCC) {
+      Function *F = CGN->getFunction();
+      if (!F)
+        continue;
+
+      if (F->isDeclaration())
+        continue;
+
+      ComputedShaderFlags CSF;
+      for (const auto &BB : *F)
+        for (const auto &I : BB)
+          updateFunctionFlags(CSF, I);
+      // Insert shader flag mask for function F
+      FunctionFlags.insert({F, CSF});
+      // Update combined shader flags mask for all functions of the module
+      CombinedSFMask.merge(CSF);
+      // Update combined shader flags mask for all functions in this SCC
+      SCCSF.merge(CSF);
+    }
+
+    if (CurSCC.size() < 2)
       continue;
-    ComputedShaderFlags CSF;
-    for (const auto &BB : F)
-      for (const auto &I : BB)
-        updateFunctionFlags(CSF, I);
-    // Insert shader flag mask for function F
-    FunctionFlags.push_back({&F, CSF});
-    // Update combined shader flags mask
-    CombinedSFMask.merge(CSF);
+
+    // Shader flags mask of each of the functions in an SCC of the call graph is
+    // the union of all functions in the SCC. Update shader flags masks of
+    // functions in SCC accordingly.
+    for (CallGraphNode *CGN : CurSCC) {
+      Function *F = CGN->getFunction();
+      if (!F)
+        continue;
+      mergeFunctionShaderFlags(F, SCCSF);
----------------
bharadwajy wrote:

> If we defer inserting into the function flags map in the loop above and just calculate SCCSF there, and also remove the `CurSCC.size() < 2` check, then we can avoid the extra merge here and simply insert the function flags for the SCC for each function at this point.

I made the changes so that computing the shader flags mask of the functions in each SCC and propagating it to their callers now happen in a single loop.

https://github.com/llvm/llvm-project/pull/118306


More information about the llvm-commits mailing list