[llvm] [DirectX] Infrastructure to collect shader flags for each function (PR #112967)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 15 12:58:57 PST 2024


https://github.com/bharadwajy updated https://github.com/llvm/llvm-project/pull/112967

>From 3da01ee7ee64c099344661d48e14960e04d19cfc Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 16 Oct 2024 14:57:57 -0400
Subject: [PATCH 01/12] [NFC][DirectX] Infrastructure to collect shader flags
 for each function

Currently, ShaderFlagsAnalysis pass represents various module-level
properties as well as function-level properties of a DXIL Module
using a single mask. However, separate flags to represent module-level
properties and function-level properties are needed for accurate computation
of shader flags mask, such as for entry function metadata creation.

This change introduces a structure that allows separate representation of

(a) shader flag mask to represent module properties
(b) a map of function to shader flag mask that represents function properties

instead of a single shader flag mask that represents module properties
and properties of all functions. The result type of the ShaderFlagsAnalysis
pass is changed to the newly-defined structure type instead of a single shader
flags mask.

This separation allows accurate computation of shader flags of an entry
function for use during its metadata generation (DXILTranslateMetadata pass)
and its feature flags in DX container globals construction (DXContainerGlobals
pass) based on the shader flags mask of functions called in entry function.
However, note that the change to implement such callee-based shader flags mask
computation is planned in a follow-on PR. Consequently, this PR changes shader
flag mask computation in DXILTranslateMetadata and DXContainerGlobals passes
to simply be a union of module flags and shader flags of all functions, thereby
retaining the existing effect of using a single shader flag mask.
---
 .../lib/Target/DirectX/DXContainerGlobals.cpp | 15 ++++--
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp   | 46 +++++++++++++------
 llvm/lib/Target/DirectX/DXILShaderFlags.h     | 26 +++++++----
 .../Target/DirectX/DXILTranslateMetadata.cpp  | 46 +++++++++++--------
 4 files changed, 87 insertions(+), 46 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index 2c11373504e8c7..c7202cc04c26dc 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -78,13 +78,18 @@ bool DXContainerGlobals::runOnModule(Module &M) {
 }
 
 GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
-  const uint64_t FeatureFlags =
-      static_cast<uint64_t>(getAnalysis<ShaderFlagsAnalysisWrapper>()
-                                .getShaderFlags()
-                                .getFeatureFlags());
+  const DXILModuleShaderFlagsInfo &MSFI =
+      getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
+  // TODO: Feature flags mask is obtained as a collection of feature flags
+  // of the shader flags of all functions in the module. Need to verify
+  // and modify the computation of feature flags to be used.
+  uint64_t ConsolidatedFeatureFlags = 0;
+  for (const auto &FuncFlags : MSFI.FuncShaderFlagsMap) {
+    ConsolidatedFeatureFlags |= FuncFlags.second.getFeatureFlags();
+  }
 
   Constant *FeatureFlagsConstant =
-      ConstantInt::get(M.getContext(), APInt(64, FeatureFlags));
+      ConstantInt::get(M.getContext(), APInt(64, ConsolidatedFeatureFlags));
   return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0");
 }
 
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 9fa137b4c025e1..8c590862008862 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -20,33 +20,41 @@
 using namespace llvm;
 using namespace llvm::dxil;
 
-static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) {
+static void updateFlags(DXILModuleShaderFlagsInfo &MSFI, const Instruction &I) {
+  ComputedShaderFlags &FSF = MSFI.FuncShaderFlagsMap[I.getFunction()];
   Type *Ty = I.getType();
   if (Ty->isDoubleTy()) {
-    Flags.Doubles = true;
+    FSF.Doubles = true;
     switch (I.getOpcode()) {
     case Instruction::FDiv:
     case Instruction::UIToFP:
     case Instruction::SIToFP:
     case Instruction::FPToUI:
     case Instruction::FPToSI:
-      Flags.DX11_1_DoubleExtensions = true;
+      FSF.DX11_1_DoubleExtensions = true;
       break;
     }
   }
 }
 
-ComputedShaderFlags ComputedShaderFlags::computeFlags(Module &M) {
-  ComputedShaderFlags Flags;
-  for (const auto &F : M)
+static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
+  DXILModuleShaderFlagsInfo MSFI;
+  for (const auto &F : M) {
+    if (F.isDeclaration())
+      continue;
+    if (!MSFI.FuncShaderFlagsMap.contains(&F)) {
+      ComputedShaderFlags CSF{};
+      MSFI.FuncShaderFlagsMap[&F] = CSF;
+    }
     for (const auto &BB : F)
       for (const auto &I : BB)
-        updateFlags(Flags, I);
-  return Flags;
+        updateFlags(MSFI, I);
+  }
+  return MSFI;
 }
 
 void ComputedShaderFlags::print(raw_ostream &OS) const {
-  uint64_t FlagVal = (uint64_t) * this;
+  uint64_t FlagVal = (uint64_t)*this;
   OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal);
   if (FlagVal == 0)
     return;
@@ -65,15 +73,25 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
 
 AnalysisKey ShaderFlagsAnalysis::Key;
 
-ComputedShaderFlags ShaderFlagsAnalysis::run(Module &M,
-                                             ModuleAnalysisManager &AM) {
-  return ComputedShaderFlags::computeFlags(M);
+DXILModuleShaderFlagsInfo ShaderFlagsAnalysis::run(Module &M,
+                                                   ModuleAnalysisManager &AM) {
+  return computeFlags(M);
+}
+
+bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
+  MSFI = computeFlags(M);
+  return false;
 }
 
 PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
                                                   ModuleAnalysisManager &AM) {
-  ComputedShaderFlags Flags = AM.getResult<ShaderFlagsAnalysis>(M);
-  Flags.print(OS);
+  DXILModuleShaderFlagsInfo Flags = AM.getResult<ShaderFlagsAnalysis>(M);
+  OS << "; Shader Flags mask for Module:\n";
+  Flags.ModuleFlags.print(OS);
+  for (auto SF : Flags.FuncShaderFlagsMap) {
+    OS << "; Shader Flags mash for Function: " << SF.first->getName() << "\n";
+    SF.second.print(OS);
+  }
   return PreservedAnalyses::all();
 }
 
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 1df7d27de13d3c..6f81ff74384d0c 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -14,6 +14,8 @@
 #ifndef LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
 #define LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
 
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/Function.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/Compiler.h"
@@ -60,11 +62,20 @@ struct ComputedShaderFlags {
     return FeatureFlags;
   }
 
-  static ComputedShaderFlags computeFlags(Module &M);
   void print(raw_ostream &OS = dbgs()) const;
   LLVM_DUMP_METHOD void dump() const { print(); }
 };
 
+using FunctionShaderFlagsMap =
+    SmallDenseMap<Function const *, ComputedShaderFlags>;
+struct DXILModuleShaderFlagsInfo {
+  // Shader Flag mask representing module-level properties
+  ComputedShaderFlags ModuleFlags;
+  // Map representing shader flag mask representing properties of each of the
+  // functions in the module
+  FunctionShaderFlagsMap FuncShaderFlagsMap;
+};
+
 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
   friend AnalysisInfoMixin<ShaderFlagsAnalysis>;
   static AnalysisKey Key;
@@ -72,9 +83,9 @@ class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
 public:
   ShaderFlagsAnalysis() = default;
 
-  using Result = ComputedShaderFlags;
+  using Result = DXILModuleShaderFlagsInfo;
 
-  ComputedShaderFlags run(Module &M, ModuleAnalysisManager &AM);
+  DXILModuleShaderFlagsInfo run(Module &M, ModuleAnalysisManager &AM);
 };
 
 /// Printer pass for ShaderFlagsAnalysis results.
@@ -92,19 +103,16 @@ class ShaderFlagsAnalysisPrinter
 /// This is required because the passes that will depend on this are codegen
 /// passes which run through the legacy pass manager.
 class ShaderFlagsAnalysisWrapper : public ModulePass {
-  ComputedShaderFlags Flags;
+  DXILModuleShaderFlagsInfo MSFI;
 
 public:
   static char ID;
 
   ShaderFlagsAnalysisWrapper() : ModulePass(ID) {}
 
-  const ComputedShaderFlags &getShaderFlags() { return Flags; }
+  const DXILModuleShaderFlagsInfo &getShaderFlags() { return MSFI; }
 
-  bool runOnModule(Module &M) override {
-    Flags = ComputedShaderFlags::computeFlags(M);
-    return false;
-  }
+  bool runOnModule(Module &M) override;
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesAll();
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index be370e10df6943..2da4fe83a066c2 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -286,11 +286,6 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
   MDTuple *Properties = nullptr;
   if (ShaderFlags != 0) {
     SmallVector<Metadata *> MDVals;
-    // FIXME: ShaderFlagsAnalysis pass needs to collect and provide
-    // ShaderFlags for each entry function. Currently, ShaderFlags value
-    // provided by ShaderFlagsAnalysis pass is created by walking *all* the
-    // function instructions of the module. Is it is correct to use this value
-    // for metadata of the empty library entry?
     MDVals.append(
         getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx));
     Properties = MDNode::get(Ctx, MDVals);
@@ -302,7 +297,7 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
 
 static void translateMetadata(Module &M, const DXILResourceMap &DRM,
                               const Resources &MDResources,
-                              const ComputedShaderFlags &ShaderFlags,
+                              const DXILModuleShaderFlagsInfo &ShaderFlags,
                               const ModuleMetadataInfo &MMDI) {
   LLVMContext &Ctx = M.getContext();
   IRBuilder<> IRB(Ctx);
@@ -318,22 +313,37 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
   // See https://github.com/llvm/llvm-project/issues/57928
   MDTuple *Signatures = nullptr;
 
-  if (MMDI.ShaderProfile == Triple::EnvironmentType::Library)
+  if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) {
+    // Create a consolidated shader flag mask of all functions in the library
+    // to be used as shader flags mask value associated with top-level library
+    // entry metadata.
+    uint64_t ConsolidatedMask = ShaderFlags.ModuleFlags;
+    for (const auto &FunFlags : ShaderFlags.FuncShaderFlagsMap) {
+      ConsolidatedMask |= FunFlags.second;
+    }
     EntryFnMDNodes.emplace_back(
-        emitTopLevelLibraryNode(M, ResourceMD, ShaderFlags));
-  else if (MMDI.EntryPropertyVec.size() > 1) {
+        emitTopLevelLibraryNode(M, ResourceMD, ConsolidatedMask));
+  } else if (MMDI.EntryPropertyVec.size() > 1) {
     M.getContext().diagnose(DiagnosticInfoTranslateMD(
         M, "Non-library shader: One and only one entry expected"));
   }
 
   for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
-    // FIXME: ShaderFlagsAnalysis pass needs to collect and provide
-    // ShaderFlags for each entry function. For now, assume shader flags value
-    // of entry functions being compiled for lib_* shader profile viz.,
-    // EntryPro.Entry is 0.
-    uint64_t EntryShaderFlags =
-        (MMDI.ShaderProfile == Triple::EnvironmentType::Library) ? 0
-                                                                 : ShaderFlags;
+    auto FSFIt = ShaderFlags.FuncShaderFlagsMap.find(EntryProp.Entry);
+    if (FSFIt == ShaderFlags.FuncShaderFlagsMap.end()) {
+      M.getContext().diagnose(DiagnosticInfoTranslateMD(
+          M, "Shader Flags of Function '" + Twine(EntryProp.Entry->getName()) +
+                 "' not found"));
+    }
+    // If ShaderProfile is Library, mask is already consolidated in the
+    // top-level library node. Hence it is not emitted.
+    uint64_t EntryShaderFlags = 0;
+    if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
+      // TODO: Create a consolidated shader flag mask of all the entry
+      // functions and its callees. The following is correct only if
+      // (*FSIt).first has no call instructions.
+      EntryShaderFlags = (*FSFIt).second | ShaderFlags.ModuleFlags;
+    }
     if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
       if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
         M.getContext().diagnose(DiagnosticInfoTranslateMD(
@@ -361,7 +371,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
                                              ModuleAnalysisManager &MAM) {
   const DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
   const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M);
-  const ComputedShaderFlags &ShaderFlags =
+  const DXILModuleShaderFlagsInfo &ShaderFlags =
       MAM.getResult<ShaderFlagsAnalysis>(M);
   const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
 
@@ -393,7 +403,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
         getAnalysis<DXILResourceWrapperPass>().getResourceMap();
     const dxil::Resources &MDResources =
         getAnalysis<DXILResourceMDWrapper>().getDXILResource();
-    const ComputedShaderFlags &ShaderFlags =
+    const DXILModuleShaderFlagsInfo &ShaderFlags =
         getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
     dxil::ModuleMetadataInfo MMDI =
         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();

>From 397f70bcb6646c7bb679551037e840a20f862173 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 18 Oct 2024 16:19:25 -0400
Subject: [PATCH 02/12] clang-format changes

---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 8c590862008862..bb350c64b5c505 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -54,7 +54,7 @@ static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
 }
 
 void ComputedShaderFlags::print(raw_ostream &OS) const {
-  uint64_t FlagVal = (uint64_t)*this;
+  uint64_t FlagVal = (uint64_t) * this;
   OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal);
   if (FlagVal == 0)
     return;

>From ae373d466a52e69d7d4c5d66c8ad85aedb03f834 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 25 Oct 2024 15:51:36 -0400
Subject: [PATCH 03/12] Use a SmallVector of pairs instead of DenseMap to
 collect Function pointers and corresponding shader flag masks. This follows
 the recommendations in LLVM Programmer's Manual as the current usage pattern
 has distinct phases of insertion of computed shader flags followed by
 querying. Upon insertion, the SmallVector is sorted and binary search is used
 for querying. Necessary comparison function of pairs is also implemented.

Added a simple DiagnosticInfoShaderFlags for emitting diagnostics.

Added tests to verify shader flags masks collected per-function.
---
 .../lib/Target/DirectX/DXContainerGlobals.cpp |   2 +-
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp   | 130 +++++++++++++++---
 llvm/lib/Target/DirectX/DXILShaderFlags.h     |  15 +-
 .../Target/DirectX/DXILTranslateMetadata.cpp  |  13 +-
 .../ShaderFlags/double-extensions-obj-test.ll |  19 +++
 .../DirectX/ShaderFlags/double-extensions.ll  |  77 +++++++++--
 .../CodeGen/DirectX/ShaderFlags/doubles.ll    |   5 +-
 .../CodeGen/DirectX/ShaderFlags/no_flags.ll   |   6 +-
 8 files changed, 218 insertions(+), 49 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions-obj-test.ll

diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index c7202cc04c26dc..81651c8cb787ab 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -84,7 +84,7 @@ GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
   // of the shader flags of all functions in the module. Need to verify
   // and modify the computation of feature flags to be used.
   uint64_t ConsolidatedFeatureFlags = 0;
-  for (const auto &FuncFlags : MSFI.FuncShaderFlagsMap) {
+  for (const auto &FuncFlags : MSFI.FuncShaderFlagsVec) {
     ConsolidatedFeatureFlags |= FuncFlags.second.getFeatureFlags();
   }
 
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index bb350c64b5c505..9afe48667ce8b4 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -13,43 +13,114 @@
 
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace llvm;
 using namespace llvm::dxil;
 
-static void updateFlags(DXILModuleShaderFlagsInfo &MSFI, const Instruction &I) {
-  ComputedShaderFlags &FSF = MSFI.FuncShaderFlagsMap[I.getFunction()];
+namespace {
+/// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
+/// for ShaderFlagsAnalysis pass
+class DiagnosticInfoShaderFlags : public DiagnosticInfo {
+private:
+  const Twine &Msg;
+  const Module &Mod;
+
+public:
+  /// \p M is the module for which the diagnostic is being emitted. \p Msg is
+  /// the message to show. Note that this class does not copy this message, so
+  /// this reference must be valid for the whole life time of the diagnostic.
+  DiagnosticInfoShaderFlags(const Module &M, const Twine &Msg,
+                            DiagnosticSeverity Severity = DS_Error)
+      : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
+
+  void print(DiagnosticPrinter &DP) const override {
+    DP << Mod.getName() << ": " << Msg << '\n';
+  }
+};
+} // namespace
+
+static void updateFlags(ComputedShaderFlags &CSF, const Instruction &I) {
   Type *Ty = I.getType();
-  if (Ty->isDoubleTy()) {
-    FSF.Doubles = true;
+  bool DoubleTyInUse = Ty->isDoubleTy();
+  for (Value *Op : I.operands()) {
+    DoubleTyInUse |= Op->getType()->isDoubleTy();
+  }
+
+  if (DoubleTyInUse) {
+    CSF.Doubles = true;
     switch (I.getOpcode()) {
     case Instruction::FDiv:
     case Instruction::UIToFP:
     case Instruction::SIToFP:
     case Instruction::FPToUI:
     case Instruction::FPToSI:
-      FSF.DX11_1_DoubleExtensions = true;
+      // TODO: To be set if I is a call to DXIL intrinsic DXIL::Opcode::Fma
+      CSF.DX11_1_DoubleExtensions = true;
       break;
     }
   }
 }
 
+static bool compareFuncSFPairs(const FuncShaderFlagsMask &First,
+                               const FuncShaderFlagsMask &Second) {
+  // Construct string representation of the functions in each pair
+  // as "retTypefunctionNamearg1Typearg2Ty..." where the function signature is
+  // retType functionName(arg1Type, arg2Ty,...).  Spaces, braces and commas are
+  //  omitted in the string representation of the signature. This allows
+  // determining a consistent lexicographical order of all functions by their
+  // signatures.
+  std::string FirstFunSig;
+  std::string SecondFunSig;
+  raw_string_ostream FRSO(FirstFunSig);
+  raw_string_ostream SRSO(SecondFunSig);
+
+  // Return type
+  First.first->getReturnType()->print(FRSO);
+  Second.first->getReturnType()->print(SRSO);
+  // Function name
+  FRSO << First.first->getName();
+  SRSO << Second.first->getName();
+  // Argument types
+  for (const Argument &Arg : First.first->args()) {
+    Arg.getType()->print(FRSO);
+  }
+  for (const Argument &Arg : Second.first->args()) {
+    Arg.getType()->print(SRSO);
+  }
+  FRSO.flush();
+  SRSO.flush();
+
+  return FRSO.str().compare(SRSO.str()) < 0;
+}
+
 static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
   DXILModuleShaderFlagsInfo MSFI;
-  for (const auto &F : M) {
+  for (auto &F : M) {
     if (F.isDeclaration())
       continue;
-    if (!MSFI.FuncShaderFlagsMap.contains(&F)) {
-      ComputedShaderFlags CSF{};
-      MSFI.FuncShaderFlagsMap[&F] = CSF;
+    // Each of the functions in a module are unique. Hence no prior shader flags
+    // mask of the function should be present.
+    if (MSFI.hasShaderFlagsMask(&F)) {
+      M.getContext().diagnose(DiagnosticInfoShaderFlags(
+          M, "Shader Flags mask for Function '" + Twine(F.getName()) +
+                 "' already exits"));
     }
+    ComputedShaderFlags CSF{};
     for (const auto &BB : F)
       for (const auto &I : BB)
-        updateFlags(MSFI, I);
+        updateFlags(CSF, I);
+    // Insert shader flag mask for function F
+    MSFI.FuncShaderFlagsVec.push_back({&F, CSF});
   }
+  // Sort MSFI.FuncShaderFlagsVec for later lookup that uses binary search
+  llvm::sort(MSFI.FuncShaderFlagsVec, compareFuncSFPairs);
   return MSFI;
 }
 
@@ -71,6 +142,38 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
   OS << ";\n";
 }
 
+void DXILModuleShaderFlagsInfo::print(raw_ostream &OS) const {
+  OS << "; Shader Flags mask for Module:\n";
+  ModuleFlags.print(OS);
+  for (auto SF : FuncShaderFlagsVec) {
+    OS << "; Shader Flags mask for Function: " << SF.first->getName() << "\n";
+    SF.second.print(OS);
+  }
+}
+
+const ComputedShaderFlags
+DXILModuleShaderFlagsInfo::getShaderFlagsMask(const Function *Func) const {
+  FuncShaderFlagsMask V{Func, {}};
+  auto Iter = llvm::lower_bound(FuncShaderFlagsVec, V, compareFuncSFPairs);
+  if (Iter == FuncShaderFlagsVec.end()) {
+    Func->getContext().diagnose(DiagnosticInfoShaderFlags(
+        *(Func->getParent()), "Shader Flags information of Function '" +
+                                  Twine(Func->getName()) + "' not found"));
+  }
+  if (Iter->first != Func) {
+    Func->getContext().diagnose(DiagnosticInfoShaderFlags(
+        *(Func->getParent()),
+        "Inconsistent Shader Flags information of Function '" +
+            Twine(Func->getName()) + "' retrieved"));
+  }
+  return Iter->second;
+}
+
+bool DXILModuleShaderFlagsInfo::hasShaderFlagsMask(const Function *Func) const {
+  FuncShaderFlagsMask V{Func, {}};
+  return llvm::binary_search(FuncShaderFlagsVec, V);
+}
+
 AnalysisKey ShaderFlagsAnalysis::Key;
 
 DXILModuleShaderFlagsInfo ShaderFlagsAnalysis::run(Module &M,
@@ -86,12 +189,7 @@ bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
 PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
                                                   ModuleAnalysisManager &AM) {
   DXILModuleShaderFlagsInfo Flags = AM.getResult<ShaderFlagsAnalysis>(M);
-  OS << "; Shader Flags mask for Module:\n";
-  Flags.ModuleFlags.print(OS);
-  for (auto SF : Flags.FuncShaderFlagsMap) {
-    OS << "; Shader Flags mash for Function: " << SF.first->getName() << "\n";
-    SF.second.print(OS);
-  }
+  Flags.print(OS);
   return PreservedAnalyses::all();
 }
 
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 6f81ff74384d0c..55967f03ca4de6 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -14,7 +14,6 @@
 #ifndef LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
 #define LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
 
-#include "llvm/ADT/DenseMap.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Pass.h"
@@ -66,14 +65,18 @@ struct ComputedShaderFlags {
   LLVM_DUMP_METHOD void dump() const { print(); }
 };
 
-using FunctionShaderFlagsMap =
-    SmallDenseMap<Function const *, ComputedShaderFlags>;
+using FuncShaderFlagsMask = std::pair<Function const *, ComputedShaderFlags>;
+using FunctionShaderFlagsVec = SmallVector<FuncShaderFlagsMask>;
 struct DXILModuleShaderFlagsInfo {
   // Shader Flag mask representing module-level properties
   ComputedShaderFlags ModuleFlags;
-  // Map representing shader flag mask representing properties of each of the
-  // functions in the module
-  FunctionShaderFlagsMap FuncShaderFlagsMap;
+  // Vector of Function-Shader Flag mask pairs representing properties of each
+  // of the functions in the module
+  FunctionShaderFlagsVec FuncShaderFlagsVec;
+
+  const ComputedShaderFlags getShaderFlagsMask(const Function *Func) const;
+  bool hasShaderFlagsMask(const Function *Func) const;
+  void print(raw_ostream &OS = dbgs()) const;
 };
 
 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 2da4fe83a066c2..f3593325b26415 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -318,7 +318,7 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
     // to be used as shader flags mask value associated with top-level library
     // entry metadata.
     uint64_t ConsolidatedMask = ShaderFlags.ModuleFlags;
-    for (const auto &FunFlags : ShaderFlags.FuncShaderFlagsMap) {
+    for (const auto &FunFlags : ShaderFlags.FuncShaderFlagsVec) {
       ConsolidatedMask |= FunFlags.second;
     }
     EntryFnMDNodes.emplace_back(
@@ -329,20 +329,15 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
   }
 
   for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
-    auto FSFIt = ShaderFlags.FuncShaderFlagsMap.find(EntryProp.Entry);
-    if (FSFIt == ShaderFlags.FuncShaderFlagsMap.end()) {
-      M.getContext().diagnose(DiagnosticInfoTranslateMD(
-          M, "Shader Flags of Function '" + Twine(EntryProp.Entry->getName()) +
-                 "' not found"));
-    }
+    ComputedShaderFlags ECSF = ShaderFlags.getShaderFlagsMask(EntryProp.Entry);
     // If ShaderProfile is Library, mask is already consolidated in the
     // top-level library node. Hence it is not emitted.
     uint64_t EntryShaderFlags = 0;
     if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
       // TODO: Create a consolidated shader flag mask of all the entry
       // functions and its callees. The following is correct only if
-      // (*FSIt).first has no call instructions.
-      EntryShaderFlags = (*FSFIt).second | ShaderFlags.ModuleFlags;
+      // EntryProp.Entry has no call instructions.
+      EntryShaderFlags = ECSF | ShaderFlags.ModuleFlags;
     }
     if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
       if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions-obj-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions-obj-test.ll
new file mode 100644
index 00000000000000..2b6b39a9c2d37e
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions-obj-test.ll
@@ -0,0 +1,19 @@
+; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=DXC
+
+target triple = "dxil-pc-shadermodel6.7-library"
+define double @div(double %a, double %b) #0 {
+  %res = fdiv double %a, %b
+  ret double %res
+}
+
+attributes #0 = { convergent norecurse nounwind "hlsl.export"}
+
+; DXC: - Name:            SFI0
+; DXC-NEXT:     Size:            8
+; DXC-NEXT:     Flags:
+; DXC-NEXT:       Doubles:         true
+; DXC-NOT:   {{[A-Za-z]+: +true}}
+; DXC:            DX11_1_DoubleExtensions:         true
+; DXC-NOT:   {{[A-Za-z]+: +true}}
+; DXC:       NextUnusedBit:   false
+; DXC: ...
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
index a8d5f9c78f0b43..7627e160514436 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
@@ -1,27 +1,74 @@
 ; RUN: opt -S --passes="print-dx-shader-flags" 2>&1 %s | FileCheck %s
-; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=DXC
 
 target triple = "dxil-pc-shadermodel6.7-library"
 
-; CHECK: ; Shader Flags Value: 0x00000044
-; CHECK: ; Note: shader requires additional functionality:
+; CHECK: ; Shader Flags mask for Module:
+; CHECK-NEXT: ; Shader Flags Value: 0x00000000
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags mask for Function: test_fdiv_double
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
 ; CHECK-NEXT: ;       Double-precision floating point
 ; CHECK-NEXT: ;       Double-precision extensions for 11.1
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
-; CHECK-NEXT: {{^;$}}
-define double @div(double %a, double %b) #0 {
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags mask for Function: test_sitofp_i64
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ;       Double-precision floating point
+; CHECK-NEXT: ;       Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags mask for Function: test_uitofp_i64
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ;       Double-precision floating point
+; CHECK-NEXT: ;       Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags mask for Function: test_fptoui_i32
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ;       Double-precision floating point
+; CHECK-NEXT: ;       Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Shader Flags mask for Function: test_fptosi_i64
+; CHECK-NEXT: ; Shader Flags Value: 0x00000044
+; CHECK-NEXT: ;
+; CHECK-NEXT: ; Note: shader requires additional functionality:
+; CHECK-NEXT: ;       Double-precision floating point
+; CHECK-NEXT: ;       Double-precision extensions for 11.1
+; CHECK-NEXT: ; Note: extra DXIL module flags:
+; CHECK-NEXT: ;
+
+define double @test_fdiv_double(double %a, double %b) #0 {
   %res = fdiv double %a, %b
   ret double %res
 }
 
-attributes #0 = { convergent norecurse nounwind "hlsl.export"}
+define double @test_uitofp_i64(i64 %a) #0 {
+  %r = uitofp i64 %a to double
+  ret double %r
+}
+
+define double @test_sitofp_i64(i64 %a) #0 {
+  %r = sitofp i64 %a to double
+  ret double %r
+}
 
-; DXC: - Name:            SFI0
-; DXC-NEXT:     Size:            8
-; DXC-NEXT:     Flags:
-; DXC-NEXT:       Doubles:         true
-; DXC-NOT:   {{[A-Za-z]+: +true}}
-; DXC:            DX11_1_DoubleExtensions:         true
-; DXC-NOT:   {{[A-Za-z]+: +true}}
-; DXC:       NextUnusedBit:   false
-; DXC: ...
+define i32 @test_fptoui_i32(double %a) #0 {
+  %r = fptoui double %a to i32
+  ret i32 %r
+}
+
+define i64 @test_fptosi_i64(double %a) #0 {
+  %r = fptosi double %a to i64
+  ret i64 %r
+}
+
+attributes #0 = { convergent norecurse nounwind "hlsl.export"}
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll
index e9b44240e10b9b..b4276c2144af97 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll
@@ -3,7 +3,10 @@
 
 target triple = "dxil-pc-shadermodel6.7-library"
 
-; CHECK: ; Shader Flags Value: 0x00000004
+; CHECK: ; Shader Flags mask for Module:
+; CHECK-NEXT: ; Shader Flags Value: 0x00000000
+; CHECK: ; Shader Flags mask for Function: add
+; CHECK-NEXT: ; Shader Flags Value: 0x00000004
 ; CHECK: ; Note: shader requires additional functionality:
 ; CHECK-NEXT: ;       Double-precision floating point
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll
index f7baa1b64f9cd1..df40fe9d86eda1 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll
@@ -2,7 +2,11 @@
 
 target triple = "dxil-pc-shadermodel6.7-library"
 
-; CHECK: ; Shader Flags Value: 0x00000000
+; CHECK: ; Shader Flags mask for Module:
+; CHECK-NEXT: ; Shader Flags Value: 0x00000000
+;
+; CHECK: ; Shader Flags mask for Function: add
+; CHECK-NEXT: ; Shader Flags Value: 0x00000000
 define i32 @add(i32 %a, i32 %b) {
   %sum = add i32 %a, %b
   ret i32 %sum

>From c02053ad2b7d2a489fa04af14b02f6f2b11818f1 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Mon, 28 Oct 2024 16:36:31 -0400
Subject: [PATCH 04/12] Compare functions by their names instead of
 constructing a pseudo-signature. Non-empty Function names are unique in LLVM
 IR.

Update the expected test output accordingly
---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp   | 35 +++----------------
 .../DirectX/ShaderFlags/double-extensions.ll  |  8 ++---
 2 files changed, 8 insertions(+), 35 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 9afe48667ce8b4..4fb008b11a0a2e 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -70,34 +70,7 @@ static void updateFlags(ComputedShaderFlags &CSF, const Instruction &I) {
 
 static bool compareFuncSFPairs(const FuncShaderFlagsMask &First,
                                const FuncShaderFlagsMask &Second) {
-  // Construct string representation of the functions in each pair
-  // as "retTypefunctionNamearg1Typearg2Ty..." where the function signature is
-  // retType functionName(arg1Type, arg2Ty,...).  Spaces, braces and commas are
-  //  omitted in the string representation of the signature. This allows
-  // determining a consistent lexicographical order of all functions by their
-  // signatures.
-  std::string FirstFunSig;
-  std::string SecondFunSig;
-  raw_string_ostream FRSO(FirstFunSig);
-  raw_string_ostream SRSO(SecondFunSig);
-
-  // Return type
-  First.first->getReturnType()->print(FRSO);
-  Second.first->getReturnType()->print(SRSO);
-  // Function name
-  FRSO << First.first->getName();
-  SRSO << Second.first->getName();
-  // Argument types
-  for (const Argument &Arg : First.first->args()) {
-    Arg.getType()->print(FRSO);
-  }
-  for (const Argument &Arg : Second.first->args()) {
-    Arg.getType()->print(SRSO);
-  }
-  FRSO.flush();
-  SRSO.flush();
-
-  return FRSO.str().compare(SRSO.str()) < 0;
+  return (First.first->getName().compare(Second.first->getName()) < 0);
 }
 
 static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
@@ -108,9 +81,9 @@ static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
     // Each of the functions in a module are unique. Hence no prior shader flags
     // mask of the function should be present.
     if (MSFI.hasShaderFlagsMask(&F)) {
-      M.getContext().diagnose(DiagnosticInfoShaderFlags(
-          M, "Shader Flags mask for Function '" + Twine(F.getName()) +
-                 "' already exits"));
+      M.getContext().diagnose(
+          DiagnosticInfoShaderFlags(M, "Shader Flags mask for Function '" +
+                                           F.getName() + "' already exists"));
     }
     ComputedShaderFlags CSF{};
     for (const auto &BB : F)
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
index 7627e160514436..dc4a90194262a0 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
@@ -13,7 +13,7 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;       Double-precision extensions for 11.1
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
 ; CHECK-NEXT: ;
-; CHECK-NEXT: ; Shader Flags mask for Function: test_sitofp_i64
+; CHECK-NEXT: ; Shader Flags mask for Function: test_fptosi_i64
 ; CHECK-NEXT: ; Shader Flags Value: 0x00000044
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Note: shader requires additional functionality:
@@ -21,7 +21,7 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;       Double-precision extensions for 11.1
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
 ; CHECK-NEXT: ;
-; CHECK-NEXT: ; Shader Flags mask for Function: test_uitofp_i64
+; CHECK-NEXT: ; Shader Flags mask for Function: test_fptoui_i32
 ; CHECK-NEXT: ; Shader Flags Value: 0x00000044
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Note: shader requires additional functionality:
@@ -29,7 +29,7 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;       Double-precision extensions for 11.1
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
 ; CHECK-NEXT: ;
-; CHECK-NEXT: ; Shader Flags mask for Function: test_fptoui_i32
+; CHECK-NEXT: ; Shader Flags mask for Function: test_sitofp_i64
 ; CHECK-NEXT: ; Shader Flags Value: 0x00000044
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Note: shader requires additional functionality:
@@ -37,7 +37,7 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;       Double-precision extensions for 11.1
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
 ; CHECK-NEXT: ;
-; CHECK-NEXT: ; Shader Flags mask for Function: test_fptosi_i64
+; CHECK-NEXT: ; Shader Flags mask for Function: test_uitofp_i64
 ; CHECK-NEXT: ; Shader Flags Value: 0x00000044
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Note: shader requires additional functionality:

>From 47ab4c5ffd71d0b08314af628a52191adcd0c438 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Mon, 28 Oct 2024 21:51:46 -0400
Subject: [PATCH 05/12] Fix incorrect lookup that expects a sorted vector done
 in a possibly unsorted vector. Other changes based on latest PR feedback

---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 35 ++++++++++-----------
 1 file changed, 17 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 4fb008b11a0a2e..e23ed87a0420a8 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -68,32 +68,37 @@ static void updateFlags(ComputedShaderFlags &CSF, const Instruction &I) {
   }
 }
 
+static bool compareFunctions(Function const *F1, Function const *F2) {
+  return (F1->getName().compare(F2->getName()) < 0);
+}
+
 static bool compareFuncSFPairs(const FuncShaderFlagsMask &First,
                                const FuncShaderFlagsMask &Second) {
-  return (First.first->getName().compare(Second.first->getName()) < 0);
+  return compareFunctions(First.first, Second.first);
 }
 
 static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
   DXILModuleShaderFlagsInfo MSFI;
+  // Create a sorted list of functions in the module
+  SmallVector<Function const *> FuncList;
   for (auto &F : M) {
     if (F.isDeclaration())
       continue;
-    // Each of the functions in a module are unique. Hence no prior shader flags
-    // mask of the function should be present.
-    if (MSFI.hasShaderFlagsMask(&F)) {
-      M.getContext().diagnose(
-          DiagnosticInfoShaderFlags(M, "Shader Flags mask for Function '" +
-                                           F.getName() + "' already exists"));
-    }
+    FuncList.push_back(&F);
+  }
+  llvm::sort(FuncList, compareFunctions);
+
+  MSFI.FuncShaderFlagsVec.clear();
+
+  // Collect shader flags for each of the functions
+  for (auto F : FuncList) {
     ComputedShaderFlags CSF{};
-    for (const auto &BB : F)
+    for (const auto &BB : *F)
       for (const auto &I : BB)
         updateFlags(CSF, I);
     // Insert shader flag mask for function F
-    MSFI.FuncShaderFlagsVec.push_back({&F, CSF});
+    MSFI.FuncShaderFlagsVec.push_back({F, CSF});
   }
-  // Sort MSFI.FuncShaderFlagsVec for later lookup that uses binary search
-  llvm::sort(MSFI.FuncShaderFlagsVec, compareFuncSFPairs);
   return MSFI;
 }
 
@@ -133,12 +138,6 @@ DXILModuleShaderFlagsInfo::getShaderFlagsMask(const Function *Func) const {
         *(Func->getParent()), "Shader Flags information of Function '" +
                                   Twine(Func->getName()) + "' not found"));
   }
-  if (Iter->first != Func) {
-    Func->getContext().diagnose(DiagnosticInfoShaderFlags(
-        *(Func->getParent()),
-        "Inconsistent Shader Flags information of Function '" +
-            Twine(Func->getName()) + "' retrieved"));
-  }
   return Iter->second;
 }
 

>From fa8ec60d67cd827e2d365cdc5b0af3579eeb9193 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Tue, 29 Oct 2024 11:01:09 -0400
Subject: [PATCH 06/12] Use getFunctionList() instead of iterating Module for
 functions.

---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index e23ed87a0420a8..c572ca51091b89 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -77,11 +77,11 @@ static bool compareFuncSFPairs(const FuncShaderFlagsMask &First,
   return compareFunctions(First.first, Second.first);
 }
 
-static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
+static DXILModuleShaderFlagsInfo computeFlags(const Module &M) {
   DXILModuleShaderFlagsInfo MSFI;
   // Create a sorted list of functions in the module
   SmallVector<Function const *> FuncList;
-  for (auto &F : M) {
+  for (const auto &F : M.getFunctionList()) {
     if (F.isDeclaration())
       continue;
     FuncList.push_back(&F);
@@ -91,7 +91,7 @@ static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
   MSFI.FuncShaderFlagsVec.clear();
 
   // Collect shader flags for each of the functions
-  for (auto F : FuncList) {
+  for (const auto &F : FuncList) {
     ComputedShaderFlags CSF{};
     for (const auto &BB : *F)
       for (const auto &I : BB)

>From a4f1e518ab09fffb62e6ed3254d747e04607eec4 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Tue, 29 Oct 2024 12:09:44 -0400
Subject: [PATCH 07/12] Change return type of getShaderFlagsMask() to
 Expected<T>. Delete unused class DiagnosticInfoShaderFlags

---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp   | 37 ++-----------------
 llvm/lib/Target/DirectX/DXILShaderFlags.h     |  3 +-
 .../Target/DirectX/DXILTranslateMetadata.cpp  | 11 +++++-
 3 files changed, 15 insertions(+), 36 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index c572ca51091b89..c2e202858b2a0b 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -14,38 +14,15 @@
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/DiagnosticInfo.h"
-#include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Module.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace llvm;
 using namespace llvm::dxil;
 
-namespace {
-/// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
-/// for ShaderFlagsAnalysis pass
-class DiagnosticInfoShaderFlags : public DiagnosticInfo {
-private:
-  const Twine &Msg;
-  const Module &Mod;
-
-public:
-  /// \p M is the module for which the diagnostic is being emitted. \p Msg is
-  /// the message to show. Note that this class does not copy this message, so
-  /// this reference must be valid for the whole life time of the diagnostic.
-  DiagnosticInfoShaderFlags(const Module &M, const Twine &Msg,
-                            DiagnosticSeverity Severity = DS_Error)
-      : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
-
-  void print(DiagnosticPrinter &DP) const override {
-    DP << Mod.getName() << ": " << Msg << '\n';
-  }
-};
-} // namespace
-
 static void updateFlags(ComputedShaderFlags &CSF, const Instruction &I) {
   Type *Ty = I.getType();
   bool DoubleTyInUse = Ty->isDoubleTy();
@@ -129,23 +106,17 @@ void DXILModuleShaderFlagsInfo::print(raw_ostream &OS) const {
   }
 }
 
-const ComputedShaderFlags
+Expected<const ComputedShaderFlags &>
 DXILModuleShaderFlagsInfo::getShaderFlagsMask(const Function *Func) const {
   FuncShaderFlagsMask V{Func, {}};
   auto Iter = llvm::lower_bound(FuncShaderFlagsVec, V, compareFuncSFPairs);
   if (Iter == FuncShaderFlagsVec.end()) {
-    Func->getContext().diagnose(DiagnosticInfoShaderFlags(
-        *(Func->getParent()), "Shader Flags information of Function '" +
-                                  Twine(Func->getName()) + "' not found"));
+    return createStringError("Shader Flags information of Function '" +
+                             Twine(Func->getName()) + "' not found");
   }
   return Iter->second;
 }
 
-bool DXILModuleShaderFlagsInfo::hasShaderFlagsMask(const Function *Func) const {
-  FuncShaderFlagsMask V{Func, {}};
-  return llvm::binary_search(FuncShaderFlagsVec, V);
-}
-
 AnalysisKey ShaderFlagsAnalysis::Key;
 
 DXILModuleShaderFlagsInfo ShaderFlagsAnalysis::run(Module &M,
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 55967f03ca4de6..7a1654ae739330 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -74,7 +74,8 @@ struct DXILModuleShaderFlagsInfo {
   // of the functions in the module
   FunctionShaderFlagsVec FuncShaderFlagsVec;
 
-  const ComputedShaderFlags getShaderFlagsMask(const Function *Func) const;
+  Expected<const ComputedShaderFlags &>
+  getShaderFlagsMask(const Function *Func) const;
   bool hasShaderFlagsMask(const Function *Func) const;
   void print(raw_ostream &OS = dbgs()) const;
 };
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index f3593325b26415..dbcbc724188ae5 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -329,7 +329,14 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
   }
 
   for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
-    ComputedShaderFlags ECSF = ShaderFlags.getShaderFlagsMask(EntryProp.Entry);
+    Expected<const ComputedShaderFlags &> ECSF =
+        ShaderFlags.getShaderFlagsMask(EntryProp.Entry);
+    if (Error E = ECSF.takeError()) {
+      M.getContext().diagnose(DiagnosticInfoTranslateMD(
+          M, "Shader Flags information of Function '" +
+                 Twine(EntryProp.Entry->getName()) + "' not found"));
+    }
+
     // If ShaderProfile is Library, mask is already consolidated in the
     // top-level library node. Hence it is not emitted.
     uint64_t EntryShaderFlags = 0;
@@ -337,7 +344,7 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
       // TODO: Create a consolidated shader flag mask of all the entry
       // functions and its callees. The following is correct only if
       // EntryProp.Entry has no call instructions.
-      EntryShaderFlags = ECSF | ShaderFlags.ModuleFlags;
+      EntryShaderFlags = *ECSF | ShaderFlags.ModuleFlags;
     }
     if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
       if (EntryProp.ShaderStage != MMDI.ShaderProfile) {

>From a6d84b225aedc32f4ca732f61166de96bf6dbe45 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 30 Oct 2024 12:22:29 -0400
Subject: [PATCH 08/12] Address PR feedback - Use CSF.Doubles directly - Remove
 type-aliases FuncShaderFlags* - Make ModuleFlags and FunctionFlags private -
 Delete DXILModuleShaderFlagsInfo::print() - Delete check prefix DXC from test
 with a single run - Get rid of compare functions - Change order of expected
 output accordingly in double-extensions.ll - Add extra comments for
 clarification - Add back DiagnosticInfoShaderFlags - Additional error checks

---
 .../lib/Target/DirectX/DXContainerGlobals.cpp |   2 +-
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp   | 123 +++++++++++++-----
 llvm/lib/Target/DirectX/DXILShaderFlags.h     |  32 +++--
 .../Target/DirectX/DXILTranslateMetadata.cpp  |  12 +-
 .../ShaderFlags/double-extensions-obj-test.ll |  20 +--
 .../DirectX/ShaderFlags/double-extensions.ll  |   8 +-
 6 files changed, 134 insertions(+), 63 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index 81651c8cb787ab..ef2333f01f752e 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -84,7 +84,7 @@ GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
   // of the shader flags of all functions in the module. Need to verify
   // and modify the computation of feature flags to be used.
   uint64_t ConsolidatedFeatureFlags = 0;
-  for (const auto &FuncFlags : MSFI.FuncShaderFlagsVec) {
+  for (const auto &FuncFlags : MSFI.getFunctionFlags()) {
     ConsolidatedFeatureFlags |= FuncFlags.second.getFeatureFlags();
   }
 
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index c2e202858b2a0b..138b2a4b469dc2 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -14,6 +14,8 @@
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/Error.h"
@@ -23,15 +25,38 @@
 using namespace llvm;
 using namespace llvm::dxil;
 
-static void updateFlags(ComputedShaderFlags &CSF, const Instruction &I) {
-  Type *Ty = I.getType();
-  bool DoubleTyInUse = Ty->isDoubleTy();
-  for (Value *Op : I.operands()) {
-    DoubleTyInUse |= Op->getType()->isDoubleTy();
+namespace {
+/// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
+/// for Shader Flags Analysis pass
+class DiagnosticInfoShaderFlags : public DiagnosticInfo {
+private:
+  const Twine &Msg;
+  const Module &Mod;
+
+public:
+  /// \p M is the module for which the diagnostic is being emitted. \p Msg is
+  /// the message to show. Note that this class does not copy this message, so
+  /// this reference must be valid for the whole life time of the diagnostic.
+  DiagnosticInfoShaderFlags(const Module &M, const Twine &Msg,
+                            DiagnosticSeverity Severity = DS_Error)
+      : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
+
+  void print(DiagnosticPrinter &DP) const override {
+    DP << Mod.getName() << ": " << Msg << '\n';
   }
+};
+} // namespace
 
-  if (DoubleTyInUse) {
-    CSF.Doubles = true;
+static void updateFlags(ComputedShaderFlags &CSF, const Instruction &I) {
+  if (!CSF.Doubles) {
+    CSF.Doubles = I.getType()->isDoubleTy();
+  }
+  if (!CSF.Doubles) {
+    for (Value *Op : I.operands()) {
+      CSF.Doubles |= Op->getType()->isDoubleTy();
+    }
+  }
+  if (CSF.Doubles) {
     switch (I.getOpcode()) {
     case Instruction::FDiv:
     case Instruction::UIToFP:
@@ -45,27 +70,20 @@ static void updateFlags(ComputedShaderFlags &CSF, const Instruction &I) {
   }
 }
 
-static bool compareFunctions(Function const *F1, Function const *F2) {
-  return (F1->getName().compare(F2->getName()) < 0);
-}
-
-static bool compareFuncSFPairs(const FuncShaderFlagsMask &First,
-                               const FuncShaderFlagsMask &Second) {
-  return compareFunctions(First.first, Second.first);
-}
-
 static DXILModuleShaderFlagsInfo computeFlags(const Module &M) {
   DXILModuleShaderFlagsInfo MSFI;
-  // Create a sorted list of functions in the module
+  // Construct a sorted list of functions in the module. Walk the sorted list to
+  // create a list of <Function, Shader Flags Mask> pairs. This list is thus
+  // sorted at construction time and may be looked up using binary search.
   SmallVector<Function const *> FuncList;
   for (const auto &F : M.getFunctionList()) {
     if (F.isDeclaration())
       continue;
     FuncList.push_back(&F);
   }
-  llvm::sort(FuncList, compareFunctions);
+  llvm::sort(FuncList);
 
-  MSFI.FuncShaderFlagsVec.clear();
+  MSFI.clear();
 
   // Collect shader flags for each of the functions
   for (const auto &F : FuncList) {
@@ -74,7 +92,10 @@ static DXILModuleShaderFlagsInfo computeFlags(const Module &M) {
       for (const auto &I : BB)
         updateFlags(CSF, I);
     // Insert shader flag mask for function F
-    MSFI.FuncShaderFlagsVec.push_back({F, CSF});
+    if (!MSFI.insertInorderFunctionFlags(F, CSF)) {
+      M.getContext().diagnose(DiagnosticInfoShaderFlags(
+          M, "Failed to add shader flags mask for function" + F->getName()));
+    }
   }
   return MSFI;
 }
@@ -97,22 +118,45 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
   OS << ";\n";
 }
 
-void DXILModuleShaderFlagsInfo::print(raw_ostream &OS) const {
-  OS << "; Shader Flags mask for Module:\n";
-  ModuleFlags.print(OS);
-  for (auto SF : FuncShaderFlagsVec) {
-    OS << "; Shader Flags mask for Function: " << SF.first->getName() << "\n";
-    SF.second.print(OS);
-  }
+void DXILModuleShaderFlagsInfo::clear() {
+  ModuleFlags = ComputedShaderFlags{};
+  FunctionFlags.clear();
+}
+
+/// Insert the pair <Func, FlagMask> into the sorted vector
+/// FunctionFlags. The insertion is expected to be in-order and hence
+/// is done at the end of the already sorted list.
+[[nodiscard]] bool DXILModuleShaderFlagsInfo::insertInorderFunctionFlags(
+    const Function *Func, ComputedShaderFlags FlagMask) {
+  std::pair<Function const *, ComputedShaderFlags> V{Func, {}};
+  auto Iter = llvm::lower_bound(FunctionFlags, V);
+  if (Iter != FunctionFlags.end())
+    return false;
+
+  FunctionFlags.push_back({Func, FlagMask});
+  return true;
+}
+
+SmallVector<std::pair<Function const *, ComputedShaderFlags>>
+DXILModuleShaderFlagsInfo::getFunctionFlags() const {
+  return FunctionFlags;
+}
+
+ComputedShaderFlags DXILModuleShaderFlagsInfo::getModuleFlags() const {
+  return ModuleFlags;
 }
 
 Expected<const ComputedShaderFlags &>
 DXILModuleShaderFlagsInfo::getShaderFlagsMask(const Function *Func) const {
-  FuncShaderFlagsMask V{Func, {}};
-  auto Iter = llvm::lower_bound(FuncShaderFlagsVec, V, compareFuncSFPairs);
-  if (Iter == FuncShaderFlagsVec.end()) {
+  std::pair<Function const *, ComputedShaderFlags> V{Func, {}};
+  // It is correct to delegate comparison of two pairs, say P1, P2, to default
+  // operator< for pairs that returns the evaluation of (P1.first < P2.first)
+  // viz., comparison of Function pointers - the same comparison criterion used
+  // for sorting module functions walked to form FunctionFLags vector..
+  auto Iter = llvm::lower_bound(FunctionFlags, V);
+  if (Iter == FunctionFlags.end()) {
     return createStringError("Shader Flags information of Function '" +
-                             Twine(Func->getName()) + "' not found");
+                             Func->getName() + "' not found");
   }
   return Iter->second;
 }
@@ -131,8 +175,21 @@ bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
 
 PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
                                                   ModuleAnalysisManager &AM) {
-  DXILModuleShaderFlagsInfo Flags = AM.getResult<ShaderFlagsAnalysis>(M);
-  Flags.print(OS);
+  DXILModuleShaderFlagsInfo FlagsInfo = AM.getResult<ShaderFlagsAnalysis>(M);
+  OS << "; Shader Flags mask for Module:\n";
+  FlagsInfo.getModuleFlags().print(OS);
+  for (const auto &F : M.getFunctionList()) {
+    if (F.isDeclaration())
+      continue;
+    OS << "; Shader Flags mask for Function: " << F.getName() << "\n";
+    auto SFMask = FlagsInfo.getShaderFlagsMask(&F);
+    if (Error E = SFMask.takeError()) {
+      M.getContext().diagnose(
+          DiagnosticInfoShaderFlags(M, toString(std::move(E))));
+    }
+    SFMask->print(OS);
+  }
+
   return PreservedAnalyses::all();
 }
 
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 7a1654ae739330..dc3451728b8d18 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -61,23 +61,37 @@ struct ComputedShaderFlags {
     return FeatureFlags;
   }
 
+  uint64_t getModuleFlags() const {
+    uint64_t ModuleFlags = 0;
+#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str)                         \
+  ModuleFlags |= FlagName ? getMask(DxilModuleBit) : 0ull;
+#include "llvm/BinaryFormat/DXContainerConstants.def"
+    return ModuleFlags;
+  }
+
   void print(raw_ostream &OS = dbgs()) const;
   LLVM_DUMP_METHOD void dump() const { print(); }
 };
 
-using FuncShaderFlagsMask = std::pair<Function const *, ComputedShaderFlags>;
-using FunctionShaderFlagsVec = SmallVector<FuncShaderFlagsMask>;
 struct DXILModuleShaderFlagsInfo {
-  // Shader Flag mask representing module-level properties
-  ComputedShaderFlags ModuleFlags;
-  // Vector of Function-Shader Flag mask pairs representing properties of each
-  // of the functions in the module
-  FunctionShaderFlagsVec FuncShaderFlagsVec;
-
   Expected<const ComputedShaderFlags &>
   getShaderFlagsMask(const Function *Func) const;
   bool hasShaderFlagsMask(const Function *Func) const;
-  void print(raw_ostream &OS = dbgs()) const;
+  void clear();
+  ComputedShaderFlags getModuleFlags() const;
+  SmallVector<std::pair<Function const *, ComputedShaderFlags>>
+  getFunctionFlags() const;
+  [[nodiscard]] bool insertInorderFunctionFlags(const Function *,
+                                                ComputedShaderFlags);
+
+private:
+  // Shader Flag mask representing module-level properties. These are
+  // represented using the macro DXIL_MODULE_FLAG
+  ComputedShaderFlags ModuleFlags;
+  // Vector of Function-Shader Flag mask pairs representing properties of each
+  // of the functions in the module. Shader Flags of each function are those
+  // represented using the macro SHADER_FEATURE_FLAG.
+  SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
 };
 
 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index dbcbc724188ae5..069469c66b2d97 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -25,6 +25,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/VersionTuple.h"
 #include "llvm/TargetParser/Triple.h"
@@ -317,8 +318,8 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
     // Create a consolidated shader flag mask of all functions in the library
     // to be used as shader flags mask value associated with top-level library
     // entry metadata.
-    uint64_t ConsolidatedMask = ShaderFlags.ModuleFlags;
-    for (const auto &FunFlags : ShaderFlags.FuncShaderFlagsVec) {
+    uint64_t ConsolidatedMask = ShaderFlags.getModuleFlags();
+    for (const auto &FunFlags : ShaderFlags.getFunctionFlags()) {
       ConsolidatedMask |= FunFlags.second;
     }
     EntryFnMDNodes.emplace_back(
@@ -332,9 +333,8 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
     Expected<const ComputedShaderFlags &> ECSF =
         ShaderFlags.getShaderFlagsMask(EntryProp.Entry);
     if (Error E = ECSF.takeError()) {
-      M.getContext().diagnose(DiagnosticInfoTranslateMD(
-          M, "Shader Flags information of Function '" +
-                 Twine(EntryProp.Entry->getName()) + "' not found"));
+      M.getContext().diagnose(
+          DiagnosticInfoTranslateMD(M, toString(std::move(E))));
     }
 
     // If ShaderProfile is Library, mask is already consolidated in the
@@ -344,7 +344,7 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
       // TODO: Create a consolidated shader flag mask of all the entry
       // functions and its callees. The following is correct only if
       // EntryProp.Entry has no call instructions.
-      EntryShaderFlags = *ECSF | ShaderFlags.ModuleFlags;
+      EntryShaderFlags = *ECSF | ShaderFlags.getModuleFlags();
     }
     if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
       if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions-obj-test.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions-obj-test.ll
index 2b6b39a9c2d37e..f920bf004c8da3 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions-obj-test.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions-obj-test.ll
@@ -1,4 +1,4 @@
-; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=DXC
+; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s
 
 target triple = "dxil-pc-shadermodel6.7-library"
 define double @div(double %a, double %b) #0 {
@@ -8,12 +8,12 @@ define double @div(double %a, double %b) #0 {
 
 attributes #0 = { convergent norecurse nounwind "hlsl.export"}
 
-; DXC: - Name:            SFI0
-; DXC-NEXT:     Size:            8
-; DXC-NEXT:     Flags:
-; DXC-NEXT:       Doubles:         true
-; DXC-NOT:   {{[A-Za-z]+: +true}}
-; DXC:            DX11_1_DoubleExtensions:         true
-; DXC-NOT:   {{[A-Za-z]+: +true}}
-; DXC:       NextUnusedBit:   false
-; DXC: ...
+; CHECK: - Name:            SFI0
+; CHECK-NEXT:     Size:            8
+; CHECK-NEXT:     Flags:
+; CHECK-NEXT:       Doubles:         true
+; CHECK-NOT:   {{[A-Za-z]+: +true}}
+; CHECK:            DX11_1_DoubleExtensions:         true
+; CHECK-NOT:   {{[A-Za-z]+: +true}}
+; CHECK:       NextUnusedBit:   false
+; CHECK: ...
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
index dc4a90194262a0..e603c8a99f8152 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
@@ -13,7 +13,7 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;       Double-precision extensions for 11.1
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
 ; CHECK-NEXT: ;
-; CHECK-NEXT: ; Shader Flags mask for Function: test_fptosi_i64
+; CHECK-NEXT: ; Shader Flags mask for Function: test_uitofp_i64
 ; CHECK-NEXT: ; Shader Flags Value: 0x00000044
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Note: shader requires additional functionality:
@@ -21,7 +21,7 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;       Double-precision extensions for 11.1
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
 ; CHECK-NEXT: ;
-; CHECK-NEXT: ; Shader Flags mask for Function: test_fptoui_i32
+; CHECK-NEXT: ; Shader Flags mask for Function: test_sitofp_i64
 ; CHECK-NEXT: ; Shader Flags Value: 0x00000044
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Note: shader requires additional functionality:
@@ -29,7 +29,7 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;       Double-precision extensions for 11.1
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
 ; CHECK-NEXT: ;
-; CHECK-NEXT: ; Shader Flags mask for Function: test_sitofp_i64
+; CHECK-NEXT: ; Shader Flags mask for Function: test_fptoui_i32
 ; CHECK-NEXT: ; Shader Flags Value: 0x00000044
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Note: shader requires additional functionality:
@@ -37,7 +37,7 @@ target triple = "dxil-pc-shadermodel6.7-library"
 ; CHECK-NEXT: ;       Double-precision extensions for 11.1
 ; CHECK-NEXT: ; Note: extra DXIL module flags:
 ; CHECK-NEXT: ;
-; CHECK-NEXT: ; Shader Flags mask for Function: test_uitofp_i64
+; CHECK-NEXT: ; Shader Flags mask for Function: test_fptosi_i64
 ; CHECK-NEXT: ; Shader Flags Value: 0x00000044
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Note: shader requires additional functionality:

>From 31b07705f6985f5eca0ea3c2bdff610657715cdc Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Mon, 4 Nov 2024 12:12:35 -0500
Subject: [PATCH 09/12] Address PR feedback

---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 37 +++++----------------
 llvm/lib/Target/DirectX/DXILShaderFlags.h   |  8 ++---
 2 files changed, 12 insertions(+), 33 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 138b2a4b469dc2..2197f8d9d6a328 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -71,11 +71,10 @@ static void updateFlags(ComputedShaderFlags &CSF, const Instruction &I) {
 }
 
 static DXILModuleShaderFlagsInfo computeFlags(const Module &M) {
-  DXILModuleShaderFlagsInfo MSFI;
   // Construct a sorted list of functions in the module. Walk the sorted list to
   // create a list of <Function, Shader Flags Mask> pairs. This list is thus
   // sorted at construction time and may be looked up using binary search.
-  SmallVector<Function const *> FuncList;
+  SmallVector<const Function *> FuncList;
   for (const auto &F : M.getFunctionList()) {
     if (F.isDeclaration())
       continue;
@@ -83,19 +82,16 @@ static DXILModuleShaderFlagsInfo computeFlags(const Module &M) {
   }
   llvm::sort(FuncList);
 
-  MSFI.clear();
+  DXILModuleShaderFlagsInfo MSFI;
 
   // Collect shader flags for each of the functions
-  for (const auto &F : FuncList) {
+  for (const Function *F : FuncList) {
     ComputedShaderFlags CSF{};
     for (const auto &BB : *F)
       for (const auto &I : BB)
         updateFlags(CSF, I);
     // Insert shader flag mask for function F
-    if (!MSFI.insertInorderFunctionFlags(F, CSF)) {
-      M.getContext().diagnose(DiagnosticInfoShaderFlags(
-          M, "Failed to add shader flags mask for function" + F->getName()));
-    }
+    MSFI.insertInorderFunctionFlags(F, CSF);
   }
   return MSFI;
 }
@@ -118,43 +114,28 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
   OS << ";\n";
 }
 
-void DXILModuleShaderFlagsInfo::clear() {
-  ModuleFlags = ComputedShaderFlags{};
-  FunctionFlags.clear();
-}
-
 /// Insert the pair <Func, FlagMask> into the sorted vector
 /// FunctionFlags. The insertion is expected to be in-order and hence
 /// is done at the end of the already sorted list.
-[[nodiscard]] bool DXILModuleShaderFlagsInfo::insertInorderFunctionFlags(
+void DXILModuleShaderFlagsInfo::insertInorderFunctionFlags(
     const Function *Func, ComputedShaderFlags FlagMask) {
-  std::pair<Function const *, ComputedShaderFlags> V{Func, {}};
-  auto Iter = llvm::lower_bound(FunctionFlags, V);
-  if (Iter != FunctionFlags.end())
-    return false;
-
   FunctionFlags.push_back({Func, FlagMask});
-  return true;
 }
 
-SmallVector<std::pair<Function const *, ComputedShaderFlags>>
+const SmallVector<std::pair<Function const *, ComputedShaderFlags>> &
 DXILModuleShaderFlagsInfo::getFunctionFlags() const {
   return FunctionFlags;
 }
 
-ComputedShaderFlags DXILModuleShaderFlagsInfo::getModuleFlags() const {
+const ComputedShaderFlags &DXILModuleShaderFlagsInfo::getModuleFlags() const {
   return ModuleFlags;
 }
 
 Expected<const ComputedShaderFlags &>
 DXILModuleShaderFlagsInfo::getShaderFlagsMask(const Function *Func) const {
   std::pair<Function const *, ComputedShaderFlags> V{Func, {}};
-  // It is correct to delegate comparison of two pairs, say P1, P2, to default
-  // operator< for pairs that returns the evaluation of (P1.first < P2.first)
-  // viz., comparison of Function pointers - the same comparison criterion used
-  // for sorting module functions walked to form FunctionFLags vector..
-  auto Iter = llvm::lower_bound(FunctionFlags, V);
-  if (Iter == FunctionFlags.end()) {
+  const auto *Iter = llvm::lower_bound(FunctionFlags, V);
+  if (Iter == FunctionFlags.end() || Iter->first != Func) {
     return createStringError("Shader Flags information of Function '" +
                              Func->getName() + "' not found");
   }
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index dc3451728b8d18..562247358e2595 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -77,12 +77,10 @@ struct DXILModuleShaderFlagsInfo {
   Expected<const ComputedShaderFlags &>
   getShaderFlagsMask(const Function *Func) const;
   bool hasShaderFlagsMask(const Function *Func) const;
-  void clear();
-  ComputedShaderFlags getModuleFlags() const;
-  SmallVector<std::pair<Function const *, ComputedShaderFlags>>
+  const ComputedShaderFlags &getModuleFlags() const;
+  const SmallVector<std::pair<Function const *, ComputedShaderFlags>> &
   getFunctionFlags() const;
-  [[nodiscard]] bool insertInorderFunctionFlags(const Function *,
-                                                ComputedShaderFlags);
+  void insertInorderFunctionFlags(const Function *, ComputedShaderFlags);
 
 private:
   // Shader Flag mask representing module-level properties. These are

>From 3427781b90b900e74d1b65e0e2fe4a89217a35ce Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 6 Nov 2024 13:42:49 -0500
Subject: [PATCH 10/12] Address PR feedback. Move the functionality of static
 void updateFlags(...) to private method void
 DXILModuleShaderFlagsInfo::updateFuctionFlags(...) and that of static
 DXILModuleShaderFlagsInfo computeFlags(const Module &M) to public method bool
 DXILModuleShaderFlagsInfo::initialize(const Module &M).

---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 35 +++++++++------------
 llvm/lib/Target/DirectX/DXILShaderFlags.h   |  2 ++
 2 files changed, 16 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 2197f8d9d6a328..23c07a3930032c 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -47,7 +47,8 @@ class DiagnosticInfoShaderFlags : public DiagnosticInfo {
 };
 } // namespace
 
-static void updateFlags(ComputedShaderFlags &CSF, const Instruction &I) {
+void DXILModuleShaderFlagsInfo::updateFuctionFlags(ComputedShaderFlags &CSF,
+                                                   const Instruction &I) {
   if (!CSF.Doubles) {
     CSF.Doubles = I.getType()->isDoubleTy();
   }
@@ -70,30 +71,20 @@ static void updateFlags(ComputedShaderFlags &CSF, const Instruction &I) {
   }
 }
 
-static DXILModuleShaderFlagsInfo computeFlags(const Module &M) {
-  // Construct a sorted list of functions in the module. Walk the sorted list to
-  // create a list of <Function, Shader Flags Mask> pairs. This list is thus
-  // sorted at construction time and may be looked up using binary search.
-  SmallVector<const Function *> FuncList;
+bool DXILModuleShaderFlagsInfo::initialize(const Module &M) {
+  // Collect shader flags for each of the functions
   for (const auto &F : M.getFunctionList()) {
     if (F.isDeclaration())
       continue;
-    FuncList.push_back(&F);
-  }
-  llvm::sort(FuncList);
-
-  DXILModuleShaderFlagsInfo MSFI;
-
-  // Collect shader flags for each of the functions
-  for (const Function *F : FuncList) {
     ComputedShaderFlags CSF{};
-    for (const auto &BB : *F)
+    for (const auto &BB : F)
       for (const auto &I : BB)
-        updateFlags(CSF, I);
+        updateFuctionFlags(CSF, I);
     // Insert shader flag mask for function F
-    MSFI.insertInorderFunctionFlags(F, CSF);
+    FunctionFlags.push_back({&F, CSF});
   }
-  return MSFI;
+  llvm::sort(FunctionFlags);
+  return true;
 }
 
 void ComputedShaderFlags::print(raw_ostream &OS) const {
@@ -134,7 +125,7 @@ const ComputedShaderFlags &DXILModuleShaderFlagsInfo::getModuleFlags() const {
 Expected<const ComputedShaderFlags &>
 DXILModuleShaderFlagsInfo::getShaderFlagsMask(const Function *Func) const {
   std::pair<Function const *, ComputedShaderFlags> V{Func, {}};
-  const auto *Iter = llvm::lower_bound(FunctionFlags, V);
+  const auto Iter = llvm::lower_bound(FunctionFlags, V);
   if (Iter == FunctionFlags.end() || Iter->first != Func) {
     return createStringError("Shader Flags information of Function '" +
                              Func->getName() + "' not found");
@@ -146,11 +137,13 @@ AnalysisKey ShaderFlagsAnalysis::Key;
 
 DXILModuleShaderFlagsInfo ShaderFlagsAnalysis::run(Module &M,
                                                    ModuleAnalysisManager &AM) {
-  return computeFlags(M);
+  DXILModuleShaderFlagsInfo MSFI;
+  MSFI.initialize(M);
+  return MSFI;
 }
 
 bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
-  MSFI = computeFlags(M);
+  MSFI.initialize(M);
   return false;
 }
 
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 562247358e2595..de34edb5ae0a40 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -74,6 +74,7 @@ struct ComputedShaderFlags {
 };
 
 struct DXILModuleShaderFlagsInfo {
+  bool initialize(const Module &M);
   Expected<const ComputedShaderFlags &>
   getShaderFlagsMask(const Function *Func) const;
   bool hasShaderFlagsMask(const Function *Func) const;
@@ -90,6 +91,7 @@ struct DXILModuleShaderFlagsInfo {
   // of the functions in the module. Shader Flags of each function are those
   // represented using the macro SHADER_FEATURE_FLAG.
   SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
+  void updateFuctionFlags(ComputedShaderFlags &CSF, const Instruction &I);
 };
 
 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {

>From 56af02a0c4795a2a2a697bcfbc5ed2327c5921ba Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 13 Nov 2024 13:22:16 -0500
Subject: [PATCH 11/12] Delete unused function

---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 8 --------
 llvm/lib/Target/DirectX/DXILShaderFlags.h   | 1 -
 2 files changed, 9 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 23c07a3930032c..6a23e53d040b1f 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -105,14 +105,6 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
   OS << ";\n";
 }
 
-/// Insert the pair <Func, FlagMask> into the sorted vector
-/// FunctionFlags. The insertion is expected to be in-order and hence
-/// is done at the end of the already sorted list.
-void DXILModuleShaderFlagsInfo::insertInorderFunctionFlags(
-    const Function *Func, ComputedShaderFlags FlagMask) {
-  FunctionFlags.push_back({Func, FlagMask});
-}
-
 const SmallVector<std::pair<Function const *, ComputedShaderFlags>> &
 DXILModuleShaderFlagsInfo::getFunctionFlags() const {
   return FunctionFlags;
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index de34edb5ae0a40..be5f2a28f33e16 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -81,7 +81,6 @@ struct DXILModuleShaderFlagsInfo {
   const ComputedShaderFlags &getModuleFlags() const;
   const SmallVector<std::pair<Function const *, ComputedShaderFlags>> &
   getFunctionFlags() const;
-  void insertInorderFunctionFlags(const Function *, ComputedShaderFlags);
 
 private:
   // Shader Flag mask representing module-level properties. These are

>From f8e501fb2daa78776edf2ec86038745a648c53c5 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 15 Nov 2024 12:49:43 -0500
Subject: [PATCH 12/12] Use function shader flag mask for module flags; track
 combined mask

Delete DXILModuleShaderFlagsInfo::ModuleFlags and track module flags
in shader flags mask of each function.

Add private field DXILModuleShaderFlagsInfo::CombinedSFMask to
represent combined shader flags masks of all functions. Update the
value as it is computed per function.

Change DXILModuleShaderFlagsInfo::initialize(Module&) to a constructor
---
 .../lib/Target/DirectX/DXContainerGlobals.cpp | 12 ++---
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp   | 47 ++++++++++---------
 llvm/lib/Target/DirectX/DXILShaderFlags.h     | 38 +++++++++------
 .../Target/DirectX/DXILTranslateMetadata.cpp  | 24 ++++------
 .../DirectX/ShaderFlags/double-extensions.ll  |  5 +-
 .../CodeGen/DirectX/ShaderFlags/doubles.ll    |  2 -
 .../CodeGen/DirectX/ShaderFlags/no_flags.ll   |  3 --
 7 files changed, 63 insertions(+), 68 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index ef2333f01f752e..181fb86892f9fd 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -78,18 +78,16 @@ bool DXContainerGlobals::runOnModule(Module &M) {
 }
 
 GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
-  const DXILModuleShaderFlagsInfo &MSFI =
-      getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
   // TODO: Feature flags mask is obtained as a collection of feature flags
   // of the shader flags of all functions in the module. Need to verify
   // and modify the computation of feature flags to be used.
-  uint64_t ConsolidatedFeatureFlags = 0;
-  for (const auto &FuncFlags : MSFI.getFunctionFlags()) {
-    ConsolidatedFeatureFlags |= FuncFlags.second.getFeatureFlags();
-  }
+  uint64_t CombinedFeatureFlags = getAnalysis<ShaderFlagsAnalysisWrapper>()
+                                      .getShaderFlags()
+                                      .getCombinedFlags()
+                                      .getFeatureFlags();
 
   Constant *FeatureFlagsConstant =
-      ConstantInt::get(M.getContext(), APInt(64, ConsolidatedFeatureFlags));
+      ConstantInt::get(M.getContext(), APInt(64, CombinedFeatureFlags));
   return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0");
 }
 
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 6a23e53d040b1f..f6e18b9708c962 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -47,8 +47,8 @@ class DiagnosticInfoShaderFlags : public DiagnosticInfo {
 };
 } // namespace
 
-void DXILModuleShaderFlagsInfo::updateFuctionFlags(ComputedShaderFlags &CSF,
-                                                   const Instruction &I) {
+void DXILModuleShaderFlagsInfo::updateFunctionFlags(ComputedShaderFlags &CSF,
+                                                    const Instruction &I) {
   if (!CSF.Doubles) {
     CSF.Doubles = I.getType()->isDoubleTy();
   }
@@ -71,7 +71,7 @@ void DXILModuleShaderFlagsInfo::updateFuctionFlags(ComputedShaderFlags &CSF,
   }
 }
 
-bool DXILModuleShaderFlagsInfo::initialize(const Module &M) {
+DXILModuleShaderFlagsInfo::DXILModuleShaderFlagsInfo(const Module &M) {
   // Collect shader flags for each of the functions
   for (const auto &F : M.getFunctionList()) {
     if (F.isDeclaration())
@@ -79,12 +79,13 @@ bool DXILModuleShaderFlagsInfo::initialize(const Module &M) {
     ComputedShaderFlags CSF{};
     for (const auto &BB : F)
       for (const auto &I : BB)
-        updateFuctionFlags(CSF, I);
+        updateFunctionFlags(CSF, I);
     // Insert shader flag mask for function F
     FunctionFlags.push_back({&F, CSF});
+    // Update combined shader flags mask
+    CombinedSFMask |= CSF;
   }
   llvm::sort(FunctionFlags);
-  return true;
 }
 
 void ComputedShaderFlags::print(raw_ostream &OS) const {
@@ -105,15 +106,13 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
   OS << ";\n";
 }
 
-const SmallVector<std::pair<Function const *, ComputedShaderFlags>> &
-DXILModuleShaderFlagsInfo::getFunctionFlags() const {
-  return FunctionFlags;
-}
-
-const ComputedShaderFlags &DXILModuleShaderFlagsInfo::getModuleFlags() const {
-  return ModuleFlags;
+/// Get the combined shader flag mask of all module functions.
+const ComputedShaderFlags DXILModuleShaderFlagsInfo::getCombinedFlags() const {
+  return CombinedSFMask;
 }
 
+/// Return the shader flags mask of the specified function Func if one exists,
+/// else an error.
 Expected<const ComputedShaderFlags &>
 DXILModuleShaderFlagsInfo::getShaderFlagsMask(const Function *Func) const {
   std::pair<Function const *, ComputedShaderFlags> V{Func, {}};
@@ -125,25 +124,21 @@ DXILModuleShaderFlagsInfo::getShaderFlagsMask(const Function *Func) const {
   return Iter->second;
 }
 
+//===----------------------------------------------------------------------===//
+// ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
+
+// Provide an explicit template instantiation for the static ID.
 AnalysisKey ShaderFlagsAnalysis::Key;
 
 DXILModuleShaderFlagsInfo ShaderFlagsAnalysis::run(Module &M,
                                                    ModuleAnalysisManager &AM) {
-  DXILModuleShaderFlagsInfo MSFI;
-  MSFI.initialize(M);
+  DXILModuleShaderFlagsInfo MSFI(M);
   return MSFI;
 }
 
-bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
-  MSFI.initialize(M);
-  return false;
-}
-
 PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
                                                   ModuleAnalysisManager &AM) {
   DXILModuleShaderFlagsInfo FlagsInfo = AM.getResult<ShaderFlagsAnalysis>(M);
-  OS << "; Shader Flags mask for Module:\n";
-  FlagsInfo.getModuleFlags().print(OS);
   for (const auto &F : M.getFunctionList()) {
     if (F.isDeclaration())
       continue;
@@ -159,6 +154,16 @@ PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
   return PreservedAnalyses::all();
 }
 
+//===----------------------------------------------------------------------===//
+// ShaderFlagsAnalysisWrapper
+
+bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
+  MSFI.reset(new DXILModuleShaderFlagsInfo(M));
+  return false;
+}
+
+void ShaderFlagsAnalysisWrapper::releaseMemory() { MSFI.reset(); }
+
 char ShaderFlagsAnalysisWrapper::ID = 0;
 
 INITIALIZE_PASS(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index be5f2a28f33e16..396cae921f9b18 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -21,6 +21,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
+#include <memory>
 
 namespace llvm {
 class Module;
@@ -69,28 +70,34 @@ struct ComputedShaderFlags {
     return ModuleFlags;
   }
 
+  ComputedShaderFlags &operator|=(const uint64_t IVal) {
+#define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleBit, FlagName, Str)          \
+  FlagName |= (IVal & getMask(DxilModuleBit));
+#define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str)                         \
+  FlagName |= (IVal & getMask(DxilModuleBit));
+#include "llvm/BinaryFormat/DXContainerConstants.def"
+    return *this;
+  }
+
   void print(raw_ostream &OS = dbgs()) const;
   LLVM_DUMP_METHOD void dump() const { print(); }
 };
 
 struct DXILModuleShaderFlagsInfo {
-  bool initialize(const Module &M);
+  DXILModuleShaderFlagsInfo(const Module &);
   Expected<const ComputedShaderFlags &>
-  getShaderFlagsMask(const Function *Func) const;
-  bool hasShaderFlagsMask(const Function *Func) const;
-  const ComputedShaderFlags &getModuleFlags() const;
-  const SmallVector<std::pair<Function const *, ComputedShaderFlags>> &
-  getFunctionFlags() const;
+  getShaderFlagsMask(const Function *) const;
+  const ComputedShaderFlags getCombinedFlags() const;
 
 private:
-  // Shader Flag mask representing module-level properties. These are
-  // represented using the macro DXIL_MODULE_FLAG
-  ComputedShaderFlags ModuleFlags;
-  // Vector of Function-Shader Flag mask pairs representing properties of each
-  // of the functions in the module. Shader Flags of each function are those
-  // represented using the macro SHADER_FEATURE_FLAG.
+  /// Vector of Function-Shader Flag mask pairs representing properties of each
+  /// of the functions in the module. Shader Flags of each function represent
+  /// both module-level and function-level flags
   SmallVector<std::pair<Function const *, ComputedShaderFlags>> FunctionFlags;
-  void updateFuctionFlags(ComputedShaderFlags &CSF, const Instruction &I);
+  /// Combined Shader Flag Mask of all functions of the module
+  ComputedShaderFlags CombinedSFMask{};
+
+  void updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I);
 };
 
 class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
@@ -120,16 +127,17 @@ class ShaderFlagsAnalysisPrinter
 /// This is required because the passes that will depend on this are codegen
 /// passes which run through the legacy pass manager.
 class ShaderFlagsAnalysisWrapper : public ModulePass {
-  DXILModuleShaderFlagsInfo MSFI;
+  std::unique_ptr<DXILModuleShaderFlagsInfo> MSFI;
 
 public:
   static char ID;
 
   ShaderFlagsAnalysisWrapper() : ModulePass(ID) {}
 
-  const DXILModuleShaderFlagsInfo &getShaderFlags() { return MSFI; }
+  const DXILModuleShaderFlagsInfo &getShaderFlags() { return *MSFI; }
 
   bool runOnModule(Module &M) override;
+  void releaseMemory() override;
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesAll();
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 069469c66b2d97..bc8d100c53fc78 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -315,24 +315,21 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
   MDTuple *Signatures = nullptr;
 
   if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) {
-    // Create a consolidated shader flag mask of all functions in the library
-    // to be used as shader flags mask value associated with top-level library
-    // entry metadata.
-    uint64_t ConsolidatedMask = ShaderFlags.getModuleFlags();
-    for (const auto &FunFlags : ShaderFlags.getFunctionFlags()) {
-      ConsolidatedMask |= FunFlags.second;
-    }
+    // Get the combined shader flag mask of all functions in the library to be
+    // used as shader flags mask value associated with top-level library entry
+    // metadata.
+    uint64_t CombinedMask = ShaderFlags.getCombinedFlags();
     EntryFnMDNodes.emplace_back(
-        emitTopLevelLibraryNode(M, ResourceMD, ConsolidatedMask));
+        emitTopLevelLibraryNode(M, ResourceMD, CombinedMask));
   } else if (MMDI.EntryPropertyVec.size() > 1) {
     M.getContext().diagnose(DiagnosticInfoTranslateMD(
         M, "Non-library shader: One and only one entry expected"));
   }
 
   for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
-    Expected<const ComputedShaderFlags &> ECSF =
+    Expected<const ComputedShaderFlags &> EntrySFMask =
         ShaderFlags.getShaderFlagsMask(EntryProp.Entry);
-    if (Error E = ECSF.takeError()) {
+    if (Error E = EntrySFMask.takeError()) {
       M.getContext().diagnose(
           DiagnosticInfoTranslateMD(M, toString(std::move(E))));
     }
@@ -341,12 +338,7 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
     // top-level library node. Hence it is not emitted.
     uint64_t EntryShaderFlags = 0;
     if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
-      // TODO: Create a consolidated shader flag mask of all the entry
-      // functions and its callees. The following is correct only if
-      // EntryProp.Entry has no call instructions.
-      EntryShaderFlags = *ECSF | ShaderFlags.getModuleFlags();
-    }
-    if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
+      EntryShaderFlags = *EntrySFMask;
       if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
         M.getContext().diagnose(DiagnosticInfoTranslateMD(
             M,
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
index e603c8a99f8152..2ba6822dccbc3f 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll
@@ -2,10 +2,7 @@
 
 target triple = "dxil-pc-shadermodel6.7-library"
 
-; CHECK: ; Shader Flags mask for Module:
-; CHECK-NEXT: ; Shader Flags Value: 0x00000000
-; CHECK-NEXT: ;
-; CHECK-NEXT: ; Shader Flags mask for Function: test_fdiv_double
+; CHECK: ; Shader Flags mask for Function: test_fdiv_double
 ; CHECK-NEXT: ; Shader Flags Value: 0x00000044
 ; CHECK-NEXT: ;
 ; CHECK-NEXT: ; Note: shader requires additional functionality:
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll
index b4276c2144af97..01b3c28ac1a67d 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/doubles.ll
@@ -3,8 +3,6 @@
 
 target triple = "dxil-pc-shadermodel6.7-library"
 
-; CHECK: ; Shader Flags mask for Module:
-; CHECK-NEXT: ; Shader Flags Value: 0x00000000
 ; CHECK: ; Shader Flags mask for Function: add
 ; CHECK-NEXT: ; Shader Flags Value: 0x00000004
 ; CHECK: ; Note: shader requires additional functionality:
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll
index df40fe9d86eda1..e3f3cbdd9fc28b 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/no_flags.ll
@@ -2,9 +2,6 @@
 
 target triple = "dxil-pc-shadermodel6.7-library"
 
-; CHECK: ; Shader Flags mask for Module:
-; CHECK-NEXT: ; Shader Flags Value: 0x00000000
-;
 ; CHECK: ; Shader Flags mask for Function: add
 ; CHECK-NEXT: ; Shader Flags Value: 0x00000000
 define i32 @add(i32 %a, i32 %b) {



More information about the llvm-commits mailing list