[llvm] [DirectX] Infrastructure to collect shader flags for each function (PR #112967)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 28 13:46:53 PDT 2024


https://github.com/bharadwajy updated https://github.com/llvm/llvm-project/pull/112967

>From 3da01ee7ee64c099344661d48e14960e04d19cfc Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 16 Oct 2024 14:57:57 -0400
Subject: [PATCH 1/4] [NFC][DirectX] Infrastructure to collect shader flags for
 each function

Currently, the ShaderFlagsAnalysis pass represents various module-level
properties as well as function-level properties of a DXIL module
using a single mask. However, separate flag masks for module-level
properties and function-level properties are needed for accurate computation
of the shader flags mask, such as when creating entry function metadata.

This change introduces a structure that allows separate representation of

(a) a shader flags mask that represents module properties, and
(b) a map from each function to a shader flags mask that represents that
    function's properties,

instead of a single shader flags mask that represents module properties
and the properties of all functions. The result type of the ShaderFlagsAnalysis
pass is changed to the newly defined structure type instead of a single shader
flags mask.

This separation allows accurate computation of the shader flags of an entry
function for use during its metadata generation (DXILTranslateMetadata pass),
and of its feature flags during DX container globals construction
(DXContainerGlobals pass), based on the shader flags masks of the functions
called by the entry function. Note, however, that the change implementing such
callee-based shader flags mask computation is planned for a follow-on PR.
Consequently, this PR changes the shader flags mask computation in the
DXILTranslateMetadata and DXContainerGlobals passes to simply be a union of the
module flags and the shader flags of all functions, thereby retaining the
existing effect of using a single shader flags mask.
---
 .../lib/Target/DirectX/DXContainerGlobals.cpp | 15 ++++--
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp   | 46 +++++++++++++------
 llvm/lib/Target/DirectX/DXILShaderFlags.h     | 26 +++++++----
 .../Target/DirectX/DXILTranslateMetadata.cpp  | 46 +++++++++++--------
 4 files changed, 87 insertions(+), 46 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index 2c11373504e8c7..c7202cc04c26dc 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -78,13 +78,18 @@ bool DXContainerGlobals::runOnModule(Module &M) {
 }
 
 GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
-  const uint64_t FeatureFlags =
-      static_cast<uint64_t>(getAnalysis<ShaderFlagsAnalysisWrapper>()
-                                .getShaderFlags()
-                                .getFeatureFlags());
+  const DXILModuleShaderFlagsInfo &MSFI =
+      getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
+  // TODO: Feature flags mask is obtained as a collection of feature flags
+  // of the shader flags of all functions in the module. Need to verify
+  // and modify the computation of feature flags to be used.
+  uint64_t ConsolidatedFeatureFlags = 0;
+  for (const auto &FuncFlags : MSFI.FuncShaderFlagsMap) {
+    ConsolidatedFeatureFlags |= FuncFlags.second.getFeatureFlags();
+  }
 
   Constant *FeatureFlagsConstant =
-      ConstantInt::get(M.getContext(), APInt(64, FeatureFlags));
+      ConstantInt::get(M.getContext(), APInt(64, ConsolidatedFeatureFlags));
   return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0");
 }
 
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 9fa137b4c025e1..8c590862008862 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -20,33 +20,41 @@
 using namespace llvm;
 using namespace llvm::dxil;
 
-static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) {
+static void updateFlags(DXILModuleShaderFlagsInfo &MSFI, const Instruction &I) {
+  ComputedShaderFlags &FSF = MSFI.FuncShaderFlagsMap[I.getFunction()];
   Type *Ty = I.getType();
   if (Ty->isDoubleTy()) {
-    Flags.Doubles = true;
+    FSF.Doubles = true;
     switch (I.getOpcode()) {
     case Instruction::FDiv:
     case Instruction::UIToFP:
     case Instruction::SIToFP:
     case Instruction::FPToUI:
     case Instruction::FPToSI:
-      Flags.DX11_1_DoubleExtensions = true;
+      FSF.DX11_1_DoubleExtensions = true;
       break;
     }
   }
 }
 
-ComputedShaderFlags ComputedShaderFlags::computeFlags(Module &M) {
-  ComputedShaderFlags Flags;
-  for (const auto &F : M)
+static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
+  DXILModuleShaderFlagsInfo MSFI;
+  for (const auto &F : M) {
+    if (F.isDeclaration())
+      continue;
+    if (!MSFI.FuncShaderFlagsMap.contains(&F)) {
+      ComputedShaderFlags CSF{};
+      MSFI.FuncShaderFlagsMap[&F] = CSF;
+    }
     for (const auto &BB : F)
       for (const auto &I : BB)
-        updateFlags(Flags, I);
-  return Flags;
+        updateFlags(MSFI, I);
+  }
+  return MSFI;
 }
 
 void ComputedShaderFlags::print(raw_ostream &OS) const {
-  uint64_t FlagVal = (uint64_t) * this;
+  uint64_t FlagVal = (uint64_t)*this;
   OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal);
   if (FlagVal == 0)
     return;
@@ -65,15 +73,25 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
 
 AnalysisKey ShaderFlagsAnalysis::Key;
 
-ComputedShaderFlags ShaderFlagsAnalysis::run(Module &M,
-                                             ModuleAnalysisManager &AM) {
-  return ComputedShaderFlags::computeFlags(M);
+DXILModuleShaderFlagsInfo ShaderFlagsAnalysis::run(Module &M,
+                                                   ModuleAnalysisManager &AM) {
+  return computeFlags(M);
+}
+
+bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
+  MSFI = computeFlags(M);
+  return false;
 }
 
 PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
                                                   ModuleAnalysisManager &AM) {
-  ComputedShaderFlags Flags = AM.getResult<ShaderFlagsAnalysis>(M);
-  Flags.print(OS);
+  DXILModuleShaderFlagsInfo Flags = AM.getResult<ShaderFlagsAnalysis>(M);
+  OS << "; Shader Flags mask for Module:\n";
+  Flags.ModuleFlags.print(OS);
+  for (auto SF : Flags.FuncShaderFlagsMap) {
+    OS << "; Shader Flags mash for Function: " << SF.first->getName() << "\n";
+    SF.second.print(OS);
+  }
   return PreservedAnalyses::all();
 }
 
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 1df7d27de13d3c..6f81ff74384d0c 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -14,6 +14,8 @@
 #ifndef LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
 #define LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
 
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/Function.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/Compiler.h"
@@ -60,11 +62,20 @@ struct ComputedShaderFlags {
     return FeatureFlags;
   }
 
-  static ComputedShaderFlags computeFlags(Module &M);
   void print(raw_ostream &OS = dbgs()) const;
   LLVM_DUMP_METHOD void dump() const { print(); }
 };
 
+using FunctionShaderFlagsMap =
+    SmallDenseMap<Function const *, ComputedShaderFlags>;
+struct DXILModuleShaderFlagsInfo {
+  // Shader Flag mask representing module-level properties
+  ComputedShaderFlags ModuleFlags;
+  // Map representing shader flag mask representing properties of each of the
+  // functions in the module
+  FunctionShaderFlagsMap FuncShaderFlagsMap;
+};
+
 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
   friend AnalysisInfoMixin<ShaderFlagsAnalysis>;
   static AnalysisKey Key;
@@ -72,9 +83,9 @@ class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
 public:
   ShaderFlagsAnalysis() = default;
 
-  using Result = ComputedShaderFlags;
+  using Result = DXILModuleShaderFlagsInfo;
 
-  ComputedShaderFlags run(Module &M, ModuleAnalysisManager &AM);
+  DXILModuleShaderFlagsInfo run(Module &M, ModuleAnalysisManager &AM);
 };
 
 /// Printer pass for ShaderFlagsAnalysis results.
@@ -92,19 +103,16 @@ class ShaderFlagsAnalysisPrinter
 /// This is required because the passes that will depend on this are codegen
 /// passes which run through the legacy pass manager.
 class ShaderFlagsAnalysisWrapper : public ModulePass {
-  ComputedShaderFlags Flags;
+  DXILModuleShaderFlagsInfo MSFI;
 
 public:
   static char ID;
 
   ShaderFlagsAnalysisWrapper() : ModulePass(ID) {}
 
-  const ComputedShaderFlags &getShaderFlags() { return Flags; }
+  const DXILModuleShaderFlagsInfo &getShaderFlags() { return MSFI; }
 
-  bool runOnModule(Module &M) override {
-    Flags = ComputedShaderFlags::computeFlags(M);
-    return false;
-  }
+  bool runOnModule(Module &M) override;
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesAll();
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index be370e10df6943..2da4fe83a066c2 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -286,11 +286,6 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
   MDTuple *Properties = nullptr;
   if (ShaderFlags != 0) {
     SmallVector<Metadata *> MDVals;
-    // FIXME: ShaderFlagsAnalysis pass needs to collect and provide
-    // ShaderFlags for each entry function. Currently, ShaderFlags value
-    // provided by ShaderFlagsAnalysis pass is created by walking *all* the
-    // function instructions of the module. Is it is correct to use this value
-    // for metadata of the empty library entry?
     MDVals.append(
         getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx));
     Properties = MDNode::get(Ctx, MDVals);
@@ -302,7 +297,7 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
 
 static void translateMetadata(Module &M, const DXILResourceMap &DRM,
                               const Resources &MDResources,
-                              const ComputedShaderFlags &ShaderFlags,
+                              const DXILModuleShaderFlagsInfo &ShaderFlags,
                               const ModuleMetadataInfo &MMDI) {
   LLVMContext &Ctx = M.getContext();
   IRBuilder<> IRB(Ctx);
@@ -318,22 +313,37 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
   // See https://github.com/llvm/llvm-project/issues/57928
   MDTuple *Signatures = nullptr;
 
-  if (MMDI.ShaderProfile == Triple::EnvironmentType::Library)
+  if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) {
+    // Create a consolidated shader flag mask of all functions in the library
+    // to be used as shader flags mask value associated with top-level library
+    // entry metadata.
+    uint64_t ConsolidatedMask = ShaderFlags.ModuleFlags;
+    for (const auto &FunFlags : ShaderFlags.FuncShaderFlagsMap) {
+      ConsolidatedMask |= FunFlags.second;
+    }
     EntryFnMDNodes.emplace_back(
-        emitTopLevelLibraryNode(M, ResourceMD, ShaderFlags));
-  else if (MMDI.EntryPropertyVec.size() > 1) {
+        emitTopLevelLibraryNode(M, ResourceMD, ConsolidatedMask));
+  } else if (MMDI.EntryPropertyVec.size() > 1) {
     M.getContext().diagnose(DiagnosticInfoTranslateMD(
         M, "Non-library shader: One and only one entry expected"));
   }
 
   for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
-    // FIXME: ShaderFlagsAnalysis pass needs to collect and provide
-    // ShaderFlags for each entry function. For now, assume shader flags value
-    // of entry functions being compiled for lib_* shader profile viz.,
-    // EntryPro.Entry is 0.
-    uint64_t EntryShaderFlags =
-        (MMDI.ShaderProfile == Triple::EnvironmentType::Library) ? 0
-                                                                 : ShaderFlags;
+    auto FSFIt = ShaderFlags.FuncShaderFlagsMap.find(EntryProp.Entry);
+    if (FSFIt == ShaderFlags.FuncShaderFlagsMap.end()) {
+      M.getContext().diagnose(DiagnosticInfoTranslateMD(
+          M, "Shader Flags of Function '" + Twine(EntryProp.Entry->getName()) +
+                 "' not found"));
+    }
+    // If ShaderProfile is Library, mask is already consolidated in the
+    // top-level library node. Hence it is not emitted.
+    uint64_t EntryShaderFlags = 0;
+    if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
+      // TODO: Create a consolidated shader flag mask of all the entry
+      // functions and its callees. The following is correct only if
+      // (*FSIt).first has no call instructions.
+      EntryShaderFlags = (*FSFIt).second | ShaderFlags.ModuleFlags;
+    }
     if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
       if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
         M.getContext().diagnose(DiagnosticInfoTranslateMD(
@@ -361,7 +371,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
                                              ModuleAnalysisManager &MAM) {
   const DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
   const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M);
-  const ComputedShaderFlags &ShaderFlags =
+  const DXILModuleShaderFlagsInfo &ShaderFlags =
       MAM.getResult<ShaderFlagsAnalysis>(M);
   const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
 
@@ -393,7 +403,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
         getAnalysis<DXILResourceWrapperPass>().getResourceMap();
     const dxil::Resources &MDResources =
         getAnalysis<DXILResourceMDWrapper>().getDXILResource();
-    const ComputedShaderFlags &ShaderFlags =
+    const DXILModuleShaderFlagsInfo &ShaderFlags =
         getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
     dxil::ModuleMetadataInfo MMDI =
         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();

>From 397f70bcb6646c7bb679551037e840a20f862173 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 18 Oct 2024 16:19:25 -0400
Subject: [PATCH 2/4] clang-format changes

---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 8c590862008862..bb350c64b5c505 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -54,7 +54,7 @@ static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
 }
 
 void ComputedShaderFlags::print(raw_ostream &OS) const {
-  uint64_t FlagVal = (uint64_t)*this;
+  uint64_t FlagVal = (uint64_t) * this;
   OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal);
   if (FlagVal == 0)
     return;

>From ae373d466a52e69d7d4c5d66c8ad85aedb03f834 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 25 Oct 2024 15:51:36 -0400
Subject: [PATCH 3/4] Use a SmallVector of pairs instead of a DenseMap to collect
 Function pointers and the corresponding shader flags masks. This follows the
 recommendations in the LLVM Programmer's Manual, as the current usage pattern
 has distinct phases: insertion of computed shader flags followed by querying.
 After all insertions, the SmallVector is sorted and binary search is used for
 querying. The necessary comparison function for the pairs is also implemented.

Added a simple DiagnosticInfoShaderFlags for emitting diagnostics.

Added tests to verify shader flags masks collected per-function.
---
 .../lib/Target/DirectX/DXContainerGlobals.cpp |   2 +-
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp   | 130 +++++++++++++++---
 llvm/lib/Target/DirectX/DXILShaderFlags.h     |  15 +-
 .../Target/DirectX/DXILTranslateMetadata.cpp  |  13 +-
 .../ShaderFlags/double-extensions-obj-test.ll |  19 +++
 .../DirectX/ShaderFlags/double-extensions.ll  |  77 +++++++++--
 .../CodeGen/DirectX/ShaderFlags/doubles.ll    |   5 +-
 .../CodeGen/DirectX/ShaderFlags/no_flags.ll   |   6 +-
 8 files changed, 218 insertions(+), 49 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions-obj-test.ll

diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index c7202cc04c26dc..81651c8cb787ab 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -84,7 +84,7 @@ GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
   // of the shader flags of all functions in the module. Need to verify
   // and modify the computation of feature flags to be used.
   uint64_t ConsolidatedFeatureFlags = 0;
-  for (const auto &FuncFlags : MSFI.FuncShaderFlagsMap) {
+  for (const auto &FuncFlags : MSFI.FuncShaderFlagsVec) {
     ConsolidatedFeatureFlags |= FuncFlags.second.getFeatureFlags();
   }
 
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index bb350c64b5c505..9afe48667ce8b4 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -13,43 +13,114 @@
 
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace llvm;
 using namespace llvm::dxil;
 
-static void updateFlags(DXILModuleShaderFlagsInfo &MSFI, const Instruction &I) {
-  ComputedShaderFlags &FSF = MSFI.FuncShaderFlagsMap[I.getFunction()];
+namespace {
+/// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
+/// for ShaderFlagsAnalysis pass
+class DiagnosticInfoShaderFlags : public DiagnosticInfo {
+private:
+  const Twine &Msg;
+  const Module &Mod;
+
+public:
+  /// \p M is the module for which the diagnostic is being emitted. \p Msg is
+  /// the message to show. Note that this class does not copy this message, so
+  /// this reference must be valid for the whole life time of the diagnostic.
+  DiagnosticInfoShaderFlags(const Module &M, const Twine &Msg,
+                            DiagnosticSeverity Severity = DS_Error)
+      : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
+
+  void print(DiagnosticPrinter &DP) const override {
+    DP << Mod.getName() << ": " << Msg << '\n';
+  }
+};
+} // namespace
+
+static void updateFlags(ComputedShaderFlags &CSF, const Instruction &I) {
   Type *Ty = I.getType();
-  if (Ty->isDoubleTy()) {
-    FSF.Doubles = true;
+  bool DoubleTyInUse = Ty->isDoubleTy();
+  for (Value *Op : I.operands()) {
+    DoubleTyInUse |= Op->getType()->isDoubleTy();
+  }
+
+  if (DoubleTyInUse) {
+    CSF.Doubles = true;
     switch (I.getOpcode()) {
     case Instruction::FDiv:
     case Instruction::UIToFP:
     case Instruction::SIToFP:
     case Instruction::FPToUI:
     case Instruction::FPToSI:
-      FSF.DX11_1_DoubleExtensions = true;
+      // TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma
+      CSF.DX11_1_DoubleExtensions = true;
       break;
     }
   }
 }
 
+static bool compareFuncSFPairs(const FuncShaderFlagsMask &First,
+                               const FuncShaderFlagsMask &Second) {
+  // Construct string representation of the functions in each pair
+  // as "retTypefunctionNamearg1Typearg2Ty..." where the function signature is
+  // retType functionName(arg1Type, arg2Ty,...).  Spaces, braces and commas are
+  //  omitted in the string representation of the signature. This allows
+  // determining a consistent lexicographical order of all functions by their
+  // signatures.
+  std::string FirstFunSig;
+  std::string SecondFunSig;
+  raw_string_ostream FRSO(FirstFunSig);
+  raw_string_ostream SRSO(SecondFunSig);
+
+  // Return type
+  First.first->getReturnType()->print(FRSO);
+  Second.first->getReturnType()->print(SRSO);
+  // Function name
+  FRSO << First.first->getName();
+  SRSO << Second.first->getName();
+  // Argument types
+  for (const Argument &Arg : First.first->args()) {
+    Arg.getType()->print(FRSO);
+  }
+  for (const Argument &Arg : Second.first->args()) {
+    Arg.getType()->print(SRSO);
+  }
+  FRSO.flush();
+  SRSO.flush();
+
+  return FRSO.str().compare(SRSO.str()) < 0;
+}
+
 static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
   DXILModuleShaderFlagsInfo MSFI;
-  for (const auto &F : M) {
+  for (auto &F : M) {
     if (F.isDeclaration())
       continue;
-    if (!MSFI.FuncShaderFlagsMap.contains(&F)) {
-      ComputedShaderFlags CSF{};
-      MSFI.FuncShaderFlagsMap[&F] = CSF;
+    // Each of the functions in a module are unique. Hence no prior shader flags
+    // mask of the function should be present.
+    if (MSFI.hasShaderFlagsMask(&F)) {
+      M.getContext().diagnose(DiagnosticInfoShaderFlags(
+          M, "Shader Flags mask for Function '" + Twine(F.getName()) +
+                 "' already exits"));
     }
+    ComputedShaderFlags CSF{};
     for (const auto &BB : F)
       for (const auto &I : BB)
-        updateFlags(MSFI, I);
+        updateFlags(CSF, I);
+    // Insert shader flag mask for function F
+    MSFI.FuncShaderFlagsVec.push_back({&F, CSF});
   }
+  // Sort MSFI.FuncShaderFlagsVec for later lookup that uses binary search
+  llvm::sort(MSFI.FuncShaderFlagsVec, compareFuncSFPairs);
   return MSFI;
 }
 
@@ -71,6 +142,38 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
   OS << ";\n";
 }
 
+void DXILModuleShaderFlagsInfo::print(raw_ostream &OS) const {
+  OS << "; Shader Flags mask for Module:\n";
+  ModuleFlags.print(OS);
+  for (auto SF : FuncShaderFlagsVec) {
+    OS << "; Shader Flags mask for Function: " << SF.first->getName() << "\n";
+    SF.second.print(OS);
+  }
+}
+
+const ComputedShaderFlags
+DXILModuleShaderFlagsInfo::getShaderFlagsMask(const Function *Func) const {
+  FuncShaderFlagsMask V{Func, {}};
+  auto Iter = llvm::lower_bound(FuncShaderFlagsVec, V, compareFuncSFPairs);
+  if (Iter == FuncShaderFlagsVec.end()) {
+    Func->getContext().diagnose(DiagnosticInfoShaderFlags(
+        *(Func->getParent()), "Shader Flags information of Function '" +
+                                  Twine(Func->getName()) + "' not found"));
+  }
+  if (Iter->first != Func) {
+    Func->getContext().diagnose(DiagnosticInfoShaderFlags(
+        *(Func->getParent()),
+        "Inconsistent Shader Flags information of Function '" +
+            Twine(Func->getName()) + "' retrieved"));
+  }
+  return Iter->second;
+}
+
+bool DXILModuleShaderFlagsInfo::hasShaderFlagsMask(const Function *Func) const {
+  FuncShaderFlagsMask V{Func, {}};
+  return llvm::binary_search(FuncShaderFlagsVec, V);
+}
+
 AnalysisKey ShaderFlagsAnalysis::Key;
 
 DXILModuleShaderFlagsInfo ShaderFlagsAnalysis::run(Module &M,
@@ -86,12 +189,7 @@ bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
 PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
                                                   ModuleAnalysisManager &AM) {
   DXILModuleShaderFlagsInfo Flags = AM.getResult<ShaderFlagsAnalysis>(M);
-  OS << "; Shader Flags mask for Module:\n";
-  Flags.ModuleFlags.print(OS);
-  for (auto SF : Flags.FuncShaderFlagsMap) {
-    OS << "; Shader Flags mash for Function: " << SF.first->getName() << "\n";
-    SF.second.print(OS);
-  }
+  Flags.print(OS);
   return PreservedAnalyses::all();
 }
 
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 6f81ff74384d0c..55967f03ca4de6 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -14,7 +14,6 @@
 #ifndef LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
 #define LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
 
-#include "llvm/ADT/DenseMap.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Pass.h"
@@ -66,14 +65,18 @@ struct ComputedShaderFlags {
   LLVM_DUMP_METHOD void dump() const { print(); }
 };
 
-using FunctionShaderFlagsMap =
-    SmallDenseMap<Function const *, ComputedShaderFlags>;
+using FuncShaderFlagsMask = std::pair<Function const *, ComputedShaderFlags>;
+using FunctionShaderFlagsVec = SmallVector<FuncShaderFlagsMask>;
 struct DXILModuleShaderFlagsInfo {
   // Shader Flag mask representing module-level properties
   ComputedShaderFlags ModuleFlags;
-  // Map representing shader flag mask representing properties of each of the
-  // functions in the module
-  FunctionShaderFlagsMap FuncShaderFlagsMap;
+  // Vector of Function-Shader Flag mask pairs representing properties of each
+  // of the functions in the module
+  FunctionShaderFlagsVec FuncShaderFlagsVec;
+
+  const ComputedShaderFlags getShaderFlagsMask(const Function *Func) const;
+  bool hasShaderFlagsMask(const Function *Func) const;
+  void print(raw_ostream &OS = dbgs()) const;
 };
 
 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 2da4fe83a066c2..f3593325b26415 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -318,7 +318,7 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
     // to be used as shader flags mask value associated with top-level library
     // entry metadata.
     uint64_t ConsolidatedMask = ShaderFlags.ModuleFlags;
-    for (const auto &FunFlags : ShaderFlags.FuncShaderFlagsMap) {
+    for (const auto &FunFlags : ShaderFlags.FuncShaderFlagsVec) {
       ConsolidatedMask |= FunFlags.second;
     }
     EntryFnMDNodes.emplace_back(
@@ -329,20 +329,15 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
   }
 
   for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
-    auto FSFIt = ShaderFlags.FuncShaderFlagsMap.find(EntryProp.Entry);
-    if (FSFIt == ShaderFlags.FuncShaderFlagsMap.end()) {
-      M.getContext().diagnose(DiagnosticInfoTranslateMD(
-          M, "Shader Flags of Function '" + Twine(EntryProp.Entry->getName()) +
-                 "' not found"));
-    }
+    ComputedShaderFlags ECSF = ShaderFlags.getShaderFlagsMask(EntryProp.Entry);
     // If ShaderProfile is Library, mask is already consolidated in the
     // top-level library node. Hence it is not emitted.
     uint64_t EntryShaderFlags = 0;
     if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
       // TODO: Create a consolidated shader flag mask of all the entry
       // functions and its callees. The following is correct only if
-      // (*FSIt).first has no call instructions.
-      EntryShaderFlags = (*FSFIt).second | ShaderFlags.ModuleFlags;
+      // EntryProp.Entry has no call instructions.
+      EntryShaderFlags = ECSF | ShaderFlags.ModuleFlags;
     }
     if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
       if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions-obj-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions-obj-test.ll
new file mode 100644
index 00000000000000..2b6b39a9c2d37e
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions-obj-test.ll
@@ -0,0 +1,19 @@
+; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=DXC
+
+target triple = "dxil-pc-shadermodel6.7-library"
+define double @div(double %a, double %b) #0 {
+  %res = fdiv double %a, %b
+  ret double %res
+}
+
+attributes #0 = { convergent norecurse nounwind "hlsl.export"}
+
+; DXC: - Name:            SFI0
+; DXC-NEXT:     Size:            8
+; DXC-NEXT:     Flags:
+; DXC-NEXT:       Doubles:         true
+; DXC-NOT:   {{[A-Za-z]+: +true}}
+; DXC:            DX11_1_DoubleExtensions:         true
+; DXC-NOT:   {{[A-Za-z]+: +true}}
+; DXC:       NextUnusedBit:   false
+; DXC: ...
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
index a8d5f9c78f0b43..7627e160514436 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
@@ -1,27 +1,74 @@
 ; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s
-; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=DXC
 
 target triple = "dxil-pc-shadermodel6.7-library"
 
-; CHECK: ; Shader Flags Value: 0x00000044
-; CHECK: ; Note: shader requires additional functionality:
+; CHECK: ; Shader Flags mask for Module:
+; CHECK-NEXT: ; Shader Flags Value: 0x00000000
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags mask for Function: test_fdiv_double
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
 ; CHECK-NEXT: ;       Double-precision floating point
 ; CHECK-NEXT: ;       Double-precision extensions for 11.1
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
-; CHECK-NEXT: {{^;$}}
-define double @div(double %a, double %b) #0 {
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags mask for Function: test_sitofp_i64
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ;       Double-precision floating point
+; CHECK-NEXT: ;       Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags mask for Function: test_uitofp_i64
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ;       Double-precision floating point
+; CHECK-NEXT: ;       Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags mask for Function: test_fptoui_i32
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ;       Double-precision floating point
+; CHECK-NEXT: ;       Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags mask for Function: test_fptosi_i64
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ;       Double-precision floating point
+; CHECK-NEXT: ;       Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+
+define double @test_fdiv_double(double %a, double %b) #0 {
   %res = fdiv double %a, %b
   ret double %res
 }
 
-attributes #0 = { convergent norecurse nounwind "hlsl.export"}
+define double @test_uitofp_i64(i64 %a) #0 {
+  %r = uitofp i64 %a to double
+  ret double %r
+}
+
+define double @test_sitofp_i64(i64 %a) #0 {
+  %r = sitofp i64 %a to double
+  ret double %r
+}
 
-; DXC: - Name:            SFI0
-; DXC-NEXT:     Size:            8
-; DXC-NEXT:     Flags:
-; DXC-NEXT:       Doubles:         true
-; DXC-NOT:   {{[A-Za-z]+: +true}}
-; DXC:            DX11_1_DoubleExtensions:         true
-; DXC-NOT:   {{[A-Za-z]+: +true}}
-; DXC:       NextUnusedBit:   false
-; DXC: ...
+define i32 @test_fptoui_i32(double %a) #0 {
+  %r = fptoui double %a to i32
+  ret i32 %r
+}
+
+define i64 @test_fptosi_i64(double %a) #0 {
+  %r = fptosi double %a to i64
+  ret i64 %r
+}
+
+attributes #0 = { convergent norecurse nounwind "hlsl.export"}
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll
index e9b44240e10b9b..b4276c2144af97 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll
@@ -3,7 +3,10 @@
 
 target triple = "dxil-pc-shadermodel6.7-library"
 
-; CHECK: ; Shader Flags Value: 0x00000004
+; CHECK: ; Shader Flags mask for Module:
+; CHECK-NEXT: ; Shader Flags Value: 0x00000000
+; CHECK: ; Shader Flags mask for Function: add
+; CHECK-NEXT: ; Shader Flags Value: 0x00000004
 ; CHECK: ; Note: shader requires additional functionality:
 ; CHECK-NEXT: ;       Double-precision floating point
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll
index f7baa1b64f9cd1..df40fe9d86eda1 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll
@@ -2,7 +2,11 @@
 
 target triple = "dxil-pc-shadermodel6.7-library"
 
-; CHECK: ; Shader Flags Value: 0x00000000
+; CHECK: ; Shader Flags mask for Module:
+; CHECK-NEXT: ; Shader Flags Value: 0x00000000
+;
+; CHECK: ; Shader Flags mask for Function: add
+; CHECK-NEXT: ; Shader Flags Value: 0x00000000
 define i32 @add(i32 %a, i32 %b) {
   %sum = add i32 %a, %b
   ret i32 %sum

>From c02053ad2b7d2a489fa04af14b02f6f2b11818f1 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Mon, 28 Oct 2024 16:36:31 -0400
Subject: [PATCH 4/4] Compare functions by their names instead of constructing
 a pseudo-signature. Non-empty Function names are unique in LLVM IR.

Update the expected test output accordingly.
---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp   | 35 +++----------------
 .../DirectX/ShaderFlags/double-extensions.ll  |  8 ++---
 2 files changed, 8 insertions(+), 35 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 9afe48667ce8b4..4fb008b11a0a2e 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -70,34 +70,7 @@ static void updateFlags(ComputedShaderFlags &CSF, const Instruction &I) {
 
 static bool compareFuncSFPairs(const FuncShaderFlagsMask &First,
                                const FuncShaderFlagsMask &Second) {
-  // Construct string representation of the functions in each pair
-  // as "retTypefunctionNamearg1Typearg2Ty..." where the function signature is
-  // retType functionName(arg1Type, arg2Ty,...).  Spaces, braces and commas are
-  //  omitted in the string representation of the signature. This allows
-  // determining a consistent lexicographical order of all functions by their
-  // signatures.
-  std::string FirstFunSig;
-  std::string SecondFunSig;
-  raw_string_ostream FRSO(FirstFunSig);
-  raw_string_ostream SRSO(SecondFunSig);
-
-  // Return type
-  First.first->getReturnType()->print(FRSO);
-  Second.first->getReturnType()->print(SRSO);
-  // Function name
-  FRSO << First.first->getName();
-  SRSO << Second.first->getName();
-  // Argument types
-  for (const Argument &Arg : First.first->args()) {
-    Arg.getType()->print(FRSO);
-  }
-  for (const Argument &Arg : Second.first->args()) {
-    Arg.getType()->print(SRSO);
-  }
-  FRSO.flush();
-  SRSO.flush();
-
-  return FRSO.str().compare(SRSO.str()) < 0;
+  return First.first->getName() < Second.first->getName();
 }
 
 static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
@@ -108,9 +81,9 @@ static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
     // Each of the functions in a module are unique. Hence no prior shader flags
     // mask of the function should be present.
     if (MSFI.hasShaderFlagsMask(&F)) {
-      M.getContext().diagnose(DiagnosticInfoShaderFlags(
-          M, "Shader Flags mask for Function '" + Twine(F.getName()) +
-                 "' already exits"));
+      M.getContext().diagnose(
+          DiagnosticInfoShaderFlags(M, "Shader Flags mask for Function '" +
+                                           F.getName() + "' already exists"));
     }
     ComputedShaderFlags CSF{};
     for (const auto &BB : F)
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
index 7627e160514436..dc4a90194262a0 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
@@ -13,7 +13,7 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;       Double-precision extensions for 11.1
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
 ; CHECK-NEXT: ;
-; CHECK-NEXT: ; Shader Flags mask for Function: test_sitofp_i64
+; CHECK-NEXT: ; Shader Flags mask for Function: test_fptosi_i64
 ; CHECK-NEXT: ; Shader Flags Value: 0x00000044
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Note: shader requires additional functionality:
@@ -21,7 +21,7 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;       Double-precision extensions for 11.1
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
 ; CHECK-NEXT: ;
-; CHECK-NEXT: ; Shader Flags mask for Function: test_uitofp_i64
+; CHECK-NEXT: ; Shader Flags mask for Function: test_fptoui_i32
 ; CHECK-NEXT: ; Shader Flags Value: 0x00000044
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Note: shader requires additional functionality:
@@ -29,7 +29,7 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;       Double-precision extensions for 11.1
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
 ; CHECK-NEXT: ;
-; CHECK-NEXT: ; Shader Flags mask for Function: test_fptoui_i32
+; CHECK-NEXT: ; Shader Flags mask for Function: test_sitofp_i64
 ; CHECK-NEXT: ; Shader Flags Value: 0x00000044
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Note: shader requires additional functionality:
@@ -37,7 +37,7 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;       Double-precision extensions for 11.1
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
 ; CHECK-NEXT: ;
-; CHECK-NEXT: ; Shader Flags mask for Function: test_fptosi_i64
+; CHECK-NEXT: ; Shader Flags mask for Function: test_uitofp_i64
 ; CHECK-NEXT: ; Shader Flags Value: 0x00000044
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Note: shader requires additional functionality:



More information about the llvm-commits mailing list