[llvm] [DirectX] Infrastructure to collect shader flags for each function (PR #112967)

Damyan Pepper via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 5 16:57:27 PST 2024


================
@@ -13,36 +13,87 @@
 
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Module.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace llvm;
 using namespace llvm::dxil;
 
-static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) {
-  Type *Ty = I.getType();
-  if (Ty->isDoubleTy()) {
-    Flags.Doubles = true;
+namespace {
+/// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
+/// for Shader Flags Analysis pass
+class DiagnosticInfoShaderFlags : public DiagnosticInfo {
+private:
+  const Twine &Msg;
+  const Module &Mod;
+
+public:
+  /// \p M is the module for which the diagnostic is being emitted. \p Msg is
+  /// the message to show. Note that this class does not copy this message, so
+  /// this reference must be valid for the whole life time of the diagnostic.
+  DiagnosticInfoShaderFlags(const Module &M, const Twine &Msg,
+                            DiagnosticSeverity Severity = DS_Error)
+      : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
+
+  void print(DiagnosticPrinter &DP) const override {
+    DP << Mod.getName() << ": " << Msg << '\n';
+  }
+};
+} // namespace
+
+static void updateFlags(ComputedShaderFlags &CSF, const Instruction &I) {
+  if (!CSF.Doubles) {
+    CSF.Doubles = I.getType()->isDoubleTy();
+  }
+  if (!CSF.Doubles) {
+    for (Value *Op : I.operands()) {
+      CSF.Doubles |= Op->getType()->isDoubleTy();
+    }
+  }
+  if (CSF.Doubles) {
     switch (I.getOpcode()) {
     case Instruction::FDiv:
     case Instruction::UIToFP:
     case Instruction::SIToFP:
     case Instruction::FPToUI:
     case Instruction::FPToSI:
-      Flags.DX11_1_DoubleExtensions = true;
+      // TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma
+      CSF.DX11_1_DoubleExtensions = true;
       break;
     }
   }
 }
 
-ComputedShaderFlags ComputedShaderFlags::computeFlags(Module &M) {
-  ComputedShaderFlags Flags;
-  for (const auto &F : M)
-    for (const auto &BB : F)
+static DXILModuleShaderFlagsInfo computeFlags(const Module &M) {
+  // Construct a sorted list of functions in the module. Walk the sorted list to
+  // create a list of <Function, Shader Flags Mask> pairs. This list is thus
+  // sorted at construction time and may be looked up using binary search.
----------------
damyanp wrote:

This really seems like quite a convoluted way of doing this. Why not just build up the list of functions/flags in one pass and sort it afterwards? Something like:

```c++
{
    SmallVector<std::pair<const Function *, ComputedShaderFlags>> FuncList;
    for (const auto &F : M.getFunctionList()) {
        if (F.isDeclaration())
            continue;
        
        ComputedShaderFlags CSF{};
        for (const auto &BB : F)
            for (const auto &I: BB)
                updateFlags(CSF, I);
        
        FuncList.push_back({&F, CSF});        
    }
    llvm::sort(FuncList);

    return DXILModuleShaderFlagsInfo(std::move(FuncList));
}
```




https://github.com/llvm/llvm-project/pull/112967


More information about the llvm-commits mailing list