[llvm] [DXIL] Consume Metadata Analysis information in passes (PR #108034)

S. Bharadwaj Yadavalli via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 19 09:59:06 PDT 2024


https://github.com/bharadwajy updated https://github.com/llvm/llvm-project/pull/108034

>From 4498e75d291287fac9042b9703f104bc412d49d5 Mon Sep 17 00:00:00 2001
From: Bharadwaj Yadavalli <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 21 Aug 2024 15:35:17 -0400
Subject: [PATCH 1/8] [DXIL] Consume Metadata Analysis information in
 DXILTranslateMetadata and DXILPrepare passes.

---
 .../llvm/Analysis/DXILMetadataAnalysis.h      |   6 +-
 llvm/lib/Analysis/DXILMetadataAnalysis.cpp    |   8 +-
 llvm/lib/Target/DirectX/CMakeLists.txt        |   1 -
 .../lib/Target/DirectX/DXContainerGlobals.cpp |  10 +-
 llvm/lib/Target/DirectX/DXILMetadata.cpp      | 335 ------------------
 llvm/lib/Target/DirectX/DXILMetadata.h        |  43 ---
 llvm/lib/Target/DirectX/DXILPrepare.cpp       |  13 +-
 .../Target/DirectX/DXILTranslateMetadata.cpp  | 286 ++++++++++++++-
 .../Target/DirectX/DirectXTargetMachine.cpp   |   1 +
 .../CodeGen/DirectX/legalize-module-flags.ll  |   2 +-
 .../CodeGen/DirectX/legalize-module-flags2.ll |   2 +-
 llvm/test/CodeGen/DirectX/strip-call-attrs.ll |   2 +-
 llvm/test/CodeGen/DirectX/typed_ptr.ll        |   2 +-
 13 files changed, 298 insertions(+), 413 deletions(-)
 delete mode 100644 llvm/lib/Target/DirectX/DXILMetadata.cpp
 delete mode 100644 llvm/lib/Target/DirectX/DXILMetadata.h

diff --git a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
index ed342c28b2d78b..cb442669a24dfe 100644
--- a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
+++ b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
@@ -21,20 +21,20 @@ class Function;
 namespace dxil {
 
 struct EntryProperties {
-  const Function *Entry;
+  const Function *Entry{nullptr};
   // Specific target shader stage may be specified for entry functions
   Triple::EnvironmentType ShaderStage = Triple::UnknownEnvironment;
   unsigned NumThreadsX{0}; // X component
   unsigned NumThreadsY{0}; // Y component
   unsigned NumThreadsZ{0}; // Z component
 
-  EntryProperties(const Function &Fn) : Entry(&Fn) {};
+  EntryProperties(const Function *Fn = nullptr) : Entry(Fn) {};
 };
 
 struct ModuleMetadataInfo {
   VersionTuple DXILVersion{};
   VersionTuple ShaderModelVersion{};
-  Triple::EnvironmentType ShaderStage = Triple::UnknownEnvironment;
+  Triple::EnvironmentType ShaderProfile = Triple::UnknownEnvironment;
   VersionTuple ValidatorVersion{};
   SmallVector<EntryProperties> EntryPropertyVec{};
   void print(raw_ostream &OS) const;
diff --git a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
index cebfe4b84dcdfb..a7f666a3f8b48f 100644
--- a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
+++ b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
@@ -27,7 +27,7 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
   Triple TT(Triple(M.getTargetTriple()));
   MMDAI.DXILVersion = TT.getDXILVersion();
   MMDAI.ShaderModelVersion = TT.getOSVersion();
-  MMDAI.ShaderStage = TT.getEnvironment();
+  MMDAI.ShaderProfile = TT.getEnvironment();
   NamedMDNode *ValidatorVerNode = M.getNamedMetadata("dx.valver");
   if (ValidatorVerNode) {
     auto *ValVerMD = cast<MDNode>(ValidatorVerNode->getOperand(0));
@@ -42,7 +42,7 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
     if (!F.hasFnAttribute("hlsl.shader"))
       continue;
 
-    EntryProperties EFP(F);
+    EntryProperties EFP(&F);
     // Get "hlsl.shader" attribute
     Attribute EntryAttr = F.getFnAttribute("hlsl.shader");
     assert(EntryAttr.isValid() &&
@@ -74,8 +74,8 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
 void ModuleMetadataInfo::print(raw_ostream &OS) const {
   OS << "Shader Model Version : " << ShaderModelVersion.getAsString() << "\n";
   OS << "DXIL Version : " << DXILVersion.getAsString() << "\n";
-  OS << "Target Shader Stage : " << Triple::getEnvironmentTypeName(ShaderStage)
-     << "\n";
+  OS << "Target Shader Stage : "
+     << Triple::getEnvironmentTypeName(ShaderProfile) << "\n";
   OS << "Validator Version : " << ValidatorVersion.getAsString() << "\n";
   for (const auto &EP : EntryPropertyVec) {
     OS << " " << EP.Entry->getName() << "\n";
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index f7ae09957996b5..55d32deb49b085 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -21,7 +21,6 @@ add_llvm_target(DirectXCodeGen
   DXContainerGlobals.cpp
   DXILFinalizeLinkage.cpp
   DXILIntrinsicExpansion.cpp
-  DXILMetadata.cpp
   DXILOpBuilder.cpp
   DXILOpLowering.cpp
   DXILPrepare.cpp
diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index 839060badf0747..2c11373504e8c7 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -204,9 +204,9 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
   dxil::ModuleMetadataInfo &MMI =
       getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
   assert(MMI.EntryPropertyVec.size() == 1 ||
-         MMI.ShaderStage == Triple::Library);
+         MMI.ShaderProfile == Triple::Library);
   PSV.BaseData.ShaderStage =
-      static_cast<uint8_t>(MMI.ShaderStage - Triple::Pixel);
+      static_cast<uint8_t>(MMI.ShaderProfile - Triple::Pixel);
 
   addResourcesForPSV(M, PSV);
 
@@ -215,7 +215,7 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
   // TODO: Lots more stuff to do here!
   //
   // See issue https://github.com/llvm/llvm-project/issues/96674.
-  switch (MMI.ShaderStage) {
+  switch (MMI.ShaderProfile) {
   case Triple::Compute:
     PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX;
     PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY;
@@ -225,10 +225,10 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
     break;
   }
 
-  if (MMI.ShaderStage != Triple::Library)
+  if (MMI.ShaderProfile != Triple::Library)
     PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName();
 
-  PSV.finalize(MMI.ShaderStage);
+  PSV.finalize(MMI.ShaderProfile);
   PSV.write(OS);
   Constant *Constant =
       ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
diff --git a/llvm/lib/Target/DirectX/DXILMetadata.cpp b/llvm/lib/Target/DirectX/DXILMetadata.cpp
deleted file mode 100644
index 1f5759c3630135..00000000000000
--- a/llvm/lib/Target/DirectX/DXILMetadata.cpp
+++ /dev/null
@@ -1,335 +0,0 @@
-//===- DXILMetadata.cpp - DXIL Metadata helper objects --------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-///
-/// \file This file contains helper objects for working with DXIL metadata.
-///
-//===----------------------------------------------------------------------===//
-
-#include "DXILMetadata.h"
-#include "llvm/IR/Constants.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Metadata.h"
-#include "llvm/IR/Module.h"
-#include "llvm/Support/VersionTuple.h"
-#include "llvm/TargetParser/Triple.h"
-
-using namespace llvm;
-using namespace llvm::dxil;
-
-ValidatorVersionMD::ValidatorVersionMD(Module &M)
-    : Entry(M.getOrInsertNamedMetadata("dx.valver")) {}
-
-void ValidatorVersionMD::update(VersionTuple ValidatorVer) {
-  auto &Ctx = Entry->getParent()->getContext();
-  IRBuilder<> B(Ctx);
-  Metadata *MDVals[2];
-  MDVals[0] = ConstantAsMetadata::get(B.getInt32(ValidatorVer.getMajor()));
-  MDVals[1] =
-      ConstantAsMetadata::get(B.getInt32(ValidatorVer.getMinor().value_or(0)));
-
-  if (isEmpty())
-    Entry->addOperand(MDNode::get(Ctx, MDVals));
-  else
-    Entry->setOperand(0, MDNode::get(Ctx, MDVals));
-}
-
-bool ValidatorVersionMD::isEmpty() { return Entry->getNumOperands() == 0; }
-
-VersionTuple ValidatorVersionMD::getAsVersionTuple() {
-  if (isEmpty())
-    return VersionTuple(1, 0);
-  auto *ValVerMD = cast<MDNode>(Entry->getOperand(0));
-  auto *MajorMD = mdconst::extract<ConstantInt>(ValVerMD->getOperand(0));
-  auto *MinorMD = mdconst::extract<ConstantInt>(ValVerMD->getOperand(1));
-  return VersionTuple(MajorMD->getZExtValue(), MinorMD->getZExtValue());
-}
-
-static StringRef getShortShaderStage(Triple::EnvironmentType Env) {
-  switch (Env) {
-  case Triple::Pixel:
-    return "ps";
-  case Triple::Vertex:
-    return "vs";
-  case Triple::Geometry:
-    return "gs";
-  case Triple::Hull:
-    return "hs";
-  case Triple::Domain:
-    return "ds";
-  case Triple::Compute:
-    return "cs";
-  case Triple::Library:
-    return "lib";
-  case Triple::Mesh:
-    return "ms";
-  case Triple::Amplification:
-    return "as";
-  default:
-    break;
-  }
-  llvm_unreachable("Unsupported environment for DXIL generation.");
-  return "";
-}
-
-void dxil::createShaderModelMD(Module &M) {
-  NamedMDNode *Entry = M.getOrInsertNamedMetadata("dx.shaderModel");
-  Triple TT(M.getTargetTriple());
-  VersionTuple Ver = TT.getOSVersion();
-  LLVMContext &Ctx = M.getContext();
-  IRBuilder<> B(Ctx);
-
-  Metadata *Vals[3];
-  Vals[0] = MDString::get(Ctx, getShortShaderStage(TT.getEnvironment()));
-  Vals[1] = ConstantAsMetadata::get(B.getInt32(Ver.getMajor()));
-  Vals[2] = ConstantAsMetadata::get(B.getInt32(Ver.getMinor().value_or(0)));
-  Entry->addOperand(MDNode::get(Ctx, Vals));
-}
-
-void dxil::createDXILVersionMD(Module &M) {
-  Triple TT(Triple::normalize(M.getTargetTriple()));
-  VersionTuple Ver = TT.getDXILVersion();
-  LLVMContext &Ctx = M.getContext();
-  IRBuilder<> B(Ctx);
-  NamedMDNode *Entry = M.getOrInsertNamedMetadata("dx.version");
-  Metadata *Vals[2];
-  Vals[0] = ConstantAsMetadata::get(B.getInt32(Ver.getMajor()));
-  Vals[1] = ConstantAsMetadata::get(B.getInt32(Ver.getMinor().value_or(0)));
-  Entry->addOperand(MDNode::get(Ctx, Vals));
-}
-
-static uint32_t getShaderStage(Triple::EnvironmentType Env) {
-  return (uint32_t)Env - (uint32_t)llvm::Triple::Pixel;
-}
-
-namespace {
-
-struct EntryProps {
-  Triple::EnvironmentType ShaderKind;
-  // FIXME: support more shader profiles.
-  // See https://github.com/llvm/llvm-project/issues/57927.
-  struct {
-    unsigned NumThreads[3];
-  } CS;
-
-  EntryProps(Function &F, Triple::EnvironmentType ModuleShaderKind)
-      : ShaderKind(ModuleShaderKind) {
-
-    if (ShaderKind == Triple::EnvironmentType::Library) {
-      Attribute EntryAttr = F.getFnAttribute("hlsl.shader");
-      StringRef EntryProfile = EntryAttr.getValueAsString();
-      Triple T("", "", "", EntryProfile);
-      ShaderKind = T.getEnvironment();
-    }
-
-    if (ShaderKind == Triple::EnvironmentType::Compute) {
-      auto NumThreadsStr =
-          F.getFnAttribute("hlsl.numthreads").getValueAsString();
-      SmallVector<StringRef> NumThreads;
-      NumThreadsStr.split(NumThreads, ',');
-      assert(NumThreads.size() == 3 && "invalid numthreads");
-      auto Zip =
-          llvm::zip(NumThreads, MutableArrayRef<unsigned>(CS.NumThreads));
-      for (auto It : Zip) {
-        StringRef Str = std::get<0>(It);
-        APInt V;
-        [[maybe_unused]] bool Result = Str.getAsInteger(10, V);
-        assert(!Result && "Failed to parse numthreads");
-
-        unsigned &Num = std::get<1>(It);
-        Num = V.getLimitedValue();
-      }
-    }
-  }
-
-  MDTuple *emitDXILEntryProps(uint64_t RawShaderFlag, LLVMContext &Ctx,
-                              bool IsLib) {
-    std::vector<Metadata *> MDVals;
-
-    if (RawShaderFlag != 0)
-      appendShaderFlags(MDVals, RawShaderFlag, Ctx);
-
-    // Add shader kind for lib entrys.
-    if (IsLib && ShaderKind != Triple::EnvironmentType::Library)
-      appendShaderKind(MDVals, Ctx);
-
-    if (ShaderKind == Triple::EnvironmentType::Compute)
-      appendNumThreads(MDVals, Ctx);
-    // FIXME: support more props.
-    // See https://github.com/llvm/llvm-project/issues/57948.
-    return MDNode::get(Ctx, MDVals);
-  }
-
-  static MDTuple *emitEntryPropsForEmptyEntry(uint64_t RawShaderFlag,
-                                              LLVMContext &Ctx) {
-    if (RawShaderFlag == 0)
-      return nullptr;
-
-    std::vector<Metadata *> MDVals;
-
-    appendShaderFlags(MDVals, RawShaderFlag, Ctx);
-    // FIXME: support more props.
-    // See https://github.com/llvm/llvm-project/issues/57948.
-    return MDNode::get(Ctx, MDVals);
-  }
-
-private:
-  enum EntryPropsTag {
-    ShaderFlagsTag = 0,
-    GSStateTag,
-    DSStateTag,
-    HSStateTag,
-    NumThreadsTag,
-    AutoBindingSpaceTag,
-    RayPayloadSizeTag,
-    RayAttribSizeTag,
-    ShaderKindTag,
-    MSStateTag,
-    ASStateTag,
-    WaveSizeTag,
-    EntryRootSigTag,
-  };
-
-  void appendNumThreads(std::vector<Metadata *> &MDVals, LLVMContext &Ctx) {
-    MDVals.emplace_back(ConstantAsMetadata::get(
-        ConstantInt::get(Type::getInt32Ty(Ctx), NumThreadsTag)));
-
-    std::vector<Metadata *> NumThreadVals;
-    for (auto Num : ArrayRef<unsigned>(CS.NumThreads))
-      NumThreadVals.emplace_back(ConstantAsMetadata::get(
-          ConstantInt::get(Type::getInt32Ty(Ctx), Num)));
-    MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
-  }
-
-  static void appendShaderFlags(std::vector<Metadata *> &MDVals,
-                                uint64_t RawShaderFlag, LLVMContext &Ctx) {
-    MDVals.emplace_back(ConstantAsMetadata::get(
-        ConstantInt::get(Type::getInt32Ty(Ctx), ShaderFlagsTag)));
-    MDVals.emplace_back(ConstantAsMetadata::get(
-        ConstantInt::get(Type::getInt64Ty(Ctx), RawShaderFlag)));
-  }
-
-  void appendShaderKind(std::vector<Metadata *> &MDVals, LLVMContext &Ctx) {
-    MDVals.emplace_back(ConstantAsMetadata::get(
-        ConstantInt::get(Type::getInt32Ty(Ctx), ShaderKindTag)));
-    MDVals.emplace_back(ConstantAsMetadata::get(
-        ConstantInt::get(Type::getInt32Ty(Ctx), getShaderStage(ShaderKind))));
-  }
-};
-
-class EntryMD {
-  Function &F;
-  LLVMContext &Ctx;
-  EntryProps Props;
-
-public:
-  EntryMD(Function &F, Triple::EnvironmentType ModuleShaderKind)
-      : F(F), Ctx(F.getContext()), Props(F, ModuleShaderKind) {}
-
-  MDTuple *emitEntryTuple(MDTuple *Resources, uint64_t RawShaderFlag) {
-    // FIXME: add signature for profile other than CS.
-    // See https://github.com/llvm/llvm-project/issues/57928.
-    MDTuple *Signatures = nullptr;
-    return emitDXILEntryPointTuple(
-        &F, F.getName().str(), Signatures, Resources,
-        Props.emitDXILEntryProps(RawShaderFlag, Ctx, /*IsLib*/ false), Ctx);
-  }
-
-  MDTuple *emitEntryTupleForLib(uint64_t RawShaderFlag) {
-    // FIXME: add signature for profile other than CS.
-    // See https://github.com/llvm/llvm-project/issues/57928.
-    MDTuple *Signatures = nullptr;
-    return emitDXILEntryPointTuple(
-        &F, F.getName().str(), Signatures,
-        /*entry in lib doesn't need resources metadata*/ nullptr,
-        Props.emitDXILEntryProps(RawShaderFlag, Ctx, /*IsLib*/ true), Ctx);
-  }
-
-  // Library will have empty entry metadata which only store the resource table
-  // metadata.
-  static MDTuple *emitEmptyEntryForLib(MDTuple *Resources,
-                                       uint64_t RawShaderFlag,
-                                       LLVMContext &Ctx) {
-    return emitDXILEntryPointTuple(
-        nullptr, "", nullptr, Resources,
-        EntryProps::emitEntryPropsForEmptyEntry(RawShaderFlag, Ctx), Ctx);
-  }
-
-private:
-  static MDTuple *emitDXILEntryPointTuple(Function *Fn, const std::string &Name,
-                                          MDTuple *Signatures,
-                                          MDTuple *Resources,
-                                          MDTuple *Properties,
-                                          LLVMContext &Ctx) {
-    Metadata *MDVals[5];
-    MDVals[0] = Fn ? ValueAsMetadata::get(Fn) : nullptr;
-    MDVals[1] = MDString::get(Ctx, Name.c_str());
-    MDVals[2] = Signatures;
-    MDVals[3] = Resources;
-    MDVals[4] = Properties;
-    return MDNode::get(Ctx, MDVals);
-  }
-};
-} // namespace
-
-void dxil::createEntryMD(Module &M, const uint64_t ShaderFlags) {
-  SmallVector<Function *> EntryList;
-  for (auto &F : M.functions()) {
-    if (!F.hasFnAttribute("hlsl.shader"))
-      continue;
-    EntryList.emplace_back(&F);
-  }
-
-  // If there are no entries, do nothing. This is mostly to allow for writing
-  // tests with no actual entry functions.
-  if (EntryList.empty())
-    return;
-
-  auto &Ctx = M.getContext();
-  // FIXME: generate metadata for resource.
-  // See https://github.com/llvm/llvm-project/issues/57926.
-  MDTuple *MDResources = nullptr;
-  if (auto *NamedResources = M.getNamedMetadata("dx.resources"))
-    MDResources = dyn_cast<MDTuple>(NamedResources->getOperand(0));
-
-  std::vector<MDNode *> Entries;
-  Triple T = Triple(M.getTargetTriple());
-  switch (T.getEnvironment()) {
-  case Triple::EnvironmentType::Library: {
-    // Add empty entry to put resource metadata.
-    MDTuple *EmptyEntry =
-        EntryMD::emitEmptyEntryForLib(MDResources, ShaderFlags, Ctx);
-    Entries.emplace_back(EmptyEntry);
-
-    for (Function *Entry : EntryList) {
-      EntryMD MD(*Entry, T.getEnvironment());
-      Entries.emplace_back(MD.emitEntryTupleForLib(0));
-    }
-  } break;
-  case Triple::EnvironmentType::Compute:
-  case Triple::EnvironmentType::Amplification:
-  case Triple::EnvironmentType::Mesh:
-  case Triple::EnvironmentType::Vertex:
-  case Triple::EnvironmentType::Hull:
-  case Triple::EnvironmentType::Domain:
-  case Triple::EnvironmentType::Geometry:
-  case Triple::EnvironmentType::Pixel: {
-    assert(EntryList.size() == 1 &&
-           "non-lib profiles should only have one entry");
-    EntryMD MD(*EntryList.front(), T.getEnvironment());
-    Entries.emplace_back(MD.emitEntryTuple(MDResources, ShaderFlags));
-  } break;
-  default:
-    assert(0 && "invalid profile");
-    break;
-  }
-
-  NamedMDNode *EntryPointsNamedMD =
-      M.getOrInsertNamedMetadata("dx.entryPoints");
-  for (auto *Entry : Entries)
-    EntryPointsNamedMD->addOperand(Entry);
-}
diff --git a/llvm/lib/Target/DirectX/DXILMetadata.h b/llvm/lib/Target/DirectX/DXILMetadata.h
deleted file mode 100644
index e05db8d5370dbe..00000000000000
--- a/llvm/lib/Target/DirectX/DXILMetadata.h
+++ /dev/null
@@ -1,43 +0,0 @@
-//===- DXILMetadata.h - DXIL Metadata helper objects ----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-///
-/// \file This file contains helper objects for working with DXIL metadata.
-///
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_TARGET_DIRECTX_DXILMETADATA_H
-#define LLVM_TARGET_DIRECTX_DXILMETADATA_H
-
-#include <stdint.h>
-
-namespace llvm {
-class Module;
-class NamedMDNode;
-class VersionTuple;
-namespace dxil {
-
-class ValidatorVersionMD {
-  NamedMDNode *Entry;
-
-public:
-  ValidatorVersionMD(Module &M);
-
-  void update(VersionTuple ValidatorVer);
-
-  bool isEmpty();
-  VersionTuple getAsVersionTuple();
-};
-
-void createShaderModelMD(Module &M);
-void createDXILVersionMD(Module &M);
-void createEntryMD(Module &M, const uint64_t ShaderFlags);
-
-} // namespace dxil
-} // namespace llvm
-
-#endif // LLVM_TARGET_DIRECTX_DXILMETADATA_H
diff --git a/llvm/lib/Target/DirectX/DXILPrepare.cpp b/llvm/lib/Target/DirectX/DXILPrepare.cpp
index b050240041dd2e..1a766c5fb7b4a8 100644
--- a/llvm/lib/Target/DirectX/DXILPrepare.cpp
+++ b/llvm/lib/Target/DirectX/DXILPrepare.cpp
@@ -11,7 +11,6 @@
 /// Language (DXIL).
 //===----------------------------------------------------------------------===//
 
-#include "DXILMetadata.h"
 #include "DXILResourceAnalysis.h"
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
@@ -174,8 +173,9 @@ class DXILPrepareModule : public ModulePass {
         AttrMask.addAttribute(I);
     }
 
-    dxil::ValidatorVersionMD ValVerMD(M);
-    VersionTuple ValVer = ValVerMD.getAsVersionTuple();
+    const dxil::ModuleMetadataInfo MetadataInfo =
+        getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
+    VersionTuple ValVer = MetadataInfo.ValidatorVersion;
     bool SkipValidation = ValVer.getMajor() == 0 && ValVer.getMinor() == 0;
 
     for (auto &F : M.functions()) {
@@ -247,10 +247,8 @@ class DXILPrepareModule : public ModulePass {
 
   DXILPrepareModule() : ModulePass(ID) {}
   void getAnalysisUsage(AnalysisUsage &AU) const override {
-    AU.addPreserved<ShaderFlagsAnalysisWrapper>();
-    AU.addPreserved<DXILResourceMDWrapper>();
-    AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
-    AU.addPreserved<DXILResourceWrapperPass>();
+    AU.setPreservesAll();
+    AU.addRequired<DXILMetadataAnalysisWrapperPass>();
   }
   static char ID; // Pass identification.
 };
@@ -260,6 +258,7 @@ char DXILPrepareModule::ID = 0;
 
 INITIALIZE_PASS_BEGIN(DXILPrepareModule, DEBUG_TYPE, "DXIL Prepare Module",
                       false, false)
+INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
 INITIALIZE_PASS_END(DXILPrepareModule, DEBUG_TYPE, "DXIL Prepare Module", false,
                     false)
 
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 11cd9df1d1dc42..3c7a9168a59257 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -7,20 +7,24 @@
 //===----------------------------------------------------------------------===//
 
 #include "DXILTranslateMetadata.h"
-#include "DXILMetadata.h"
 #include "DXILResource.h"
 #include "DXILResourceAnalysis.h"
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
-#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/Analysis/DXILResource.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/VersionTuple.h"
 #include "llvm/TargetParser/Triple.h"
+#include <cstdint>
 
 using namespace llvm;
 using namespace llvm::dxil;
@@ -65,18 +69,273 @@ static void emitResourceMetadata(Module &M, const DXILResourceMap &DRM,
       MDNode::get(M.getContext(), {SRVMD, UAVMD, CBufMD, SmpMD}));
 }
 
+static StringRef getShortShaderStage(Triple::EnvironmentType Env) {
+  switch (Env) {
+  case Triple::Pixel:
+    return "ps";
+  case Triple::Vertex:
+    return "vs";
+  case Triple::Geometry:
+    return "gs";
+  case Triple::Hull:
+    return "hs";
+  case Triple::Domain:
+    return "ds";
+  case Triple::Compute:
+    return "cs";
+  case Triple::Library:
+    return "lib";
+  case Triple::Mesh:
+    return "ms";
+  case Triple::Amplification:
+    return "as";
+  default:
+    break;
+  }
+  llvm_unreachable("Unsupported environment for DXIL generation.");
+  return "";
+}
+
+static uint32_t getShaderStage(Triple::EnvironmentType Env) {
+  return (uint32_t)Env - (uint32_t)llvm::Triple::Pixel;
+}
+
+struct ShaderEntryMDInfo : EntryProperties {
+
+  enum EntryPropsTag {
+    ShaderFlagsTag = 0,
+    GSStateTag,
+    DSStateTag,
+    HSStateTag,
+    NumThreadsTag,
+    AutoBindingSpaceTag,
+    RayPayloadSizeTag,
+    RayAttribSizeTag,
+    ShaderKindTag,
+    MSStateTag,
+    ASStateTag,
+    WaveSizeTag,
+    EntryRootSigTag,
+  };
+
+  ShaderEntryMDInfo(EntryProperties &EP, LLVMContext &C,
+                    Triple::EnvironmentType SP, MDTuple *MDR = nullptr,
+                    uint64_t ShaderFlags = 0)
+      : EntryProperties(EP), Ctx(C), EntryShaderFlags(ShaderFlags),
+        MDResources(MDR), ShaderProfile(SP) {};
+
+  MDTuple *getAsMetadata() {
+    MDTuple *Properties = constructEntryPropMetadata();
+    // FIXME: Add support to construct Signatures
+    // See https://github.com/llvm/llvm-project/issues/57928
+    MDTuple *Signatures = nullptr;
+    return constructEntryMetadata(Signatures, MDResources, Properties);
+  }
+
+private:
+  LLVMContext &Ctx;
+  // Shader Flags for the Entry - from ShaderFlagsAnalysis pass
+  uint64_t EntryShaderFlags{0};
+  MDTuple *MDResources{nullptr};
+  Triple::EnvironmentType ShaderProfile{
+      Triple::EnvironmentType::UnknownEnvironment};
+  // Each entry point metadata record specifies:
+  //  * reference to the entry point function global symbol
+  //  * unmangled name
+  //  * list of signatures
+  //  * list of resources
+  //  * list of tag-value pairs of shader capabilities and other properties
+
+  MDTuple *constructEntryMetadata(MDTuple *Signatures, MDTuple *Resources,
+                                  MDTuple *Properties) {
+    Metadata *MDVals[5];
+    MDVals[0] =
+        Entry ? ValueAsMetadata::get(const_cast<Function *>(Entry)) : nullptr;
+    MDVals[1] = MDString::get(Ctx, Entry ? Entry->getName() : "");
+    MDVals[2] = Signatures;
+    MDVals[3] = Resources;
+    MDVals[4] = Properties;
+    return MDNode::get(Ctx, MDVals);
+  }
+
+  SmallVector<Metadata *> getTagValueAsMetadata(EntryPropsTag Tag,
+                                                uint64_t Value) {
+    SmallVector<Metadata *> MDVals;
+    MDVals.emplace_back(
+        ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Ctx), Tag)));
+    switch (Tag) {
+    case ShaderFlagsTag:
+      MDVals.emplace_back(ConstantAsMetadata::get(
+          ConstantInt::get(Type::getInt64Ty(Ctx), Value)));
+      break;
+    case ShaderKindTag:
+      MDVals.emplace_back(ConstantAsMetadata::get(
+          ConstantInt::get(Type::getInt32Ty(Ctx), Value)));
+      break;
+    default:
+      assert(false && "NYI: Unhandled entry property tag");
+    }
+    return MDVals;
+  }
+
+  MDTuple *constructEntryPropMetadata() {
+    SmallVector<Metadata *> MDVals;
+    if (EntryShaderFlags != 0)
+      MDVals.append(getTagValueAsMetadata(ShaderFlagsTag, EntryShaderFlags));
+
+    if (Entry != nullptr) {
+      // FIXME: support more props.
+      // See https://github.com/llvm/llvm-project/issues/57948.
+      // Add shader kind for lib entries.
+      if (ShaderProfile == Triple::EnvironmentType::Library &&
+          ShaderStage != Triple::EnvironmentType::Library)
+        MDVals.append(
+            getTagValueAsMetadata(ShaderKindTag, getShaderStage(ShaderStage)));
+
+      if (ShaderStage == Triple::EnvironmentType::Compute) {
+        MDVals.emplace_back(ConstantAsMetadata::get(
+            ConstantInt::get(Type::getInt32Ty(Ctx), NumThreadsTag)));
+        std::vector<Metadata *> NumThreadVals;
+        NumThreadVals.emplace_back(ConstantAsMetadata::get(
+            ConstantInt::get(Type::getInt32Ty(Ctx), NumThreadsX)));
+        NumThreadVals.emplace_back(ConstantAsMetadata::get(
+            ConstantInt::get(Type::getInt32Ty(Ctx), NumThreadsY)));
+        NumThreadVals.emplace_back(ConstantAsMetadata::get(
+            ConstantInt::get(Type::getInt32Ty(Ctx), NumThreadsZ)));
+        MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
+      }
+    }
+    if (MDVals.empty())
+      return nullptr;
+    return MDNode::get(Ctx, MDVals);
+  }
+};
+
+static void createEntryMD(Module &M, const uint64_t ShaderFlags,
+                          const dxil::ModuleMetadataInfo &MDAnalysisInfo) {
+  auto &Ctx = M.getContext();
+  // FIXME: generate metadata for resource.
+  MDTuple *MDResources = nullptr;
+  if (auto *NamedResources = M.getNamedMetadata("dx.resources"))
+    MDResources = dyn_cast<MDTuple>(NamedResources->getOperand(0));
+
+  std::vector<MDNode *> EntryFnMDNodes;
+  switch (MDAnalysisInfo.ShaderProfile) {
+  case Triple::EnvironmentType::Library: {
+    // Library has an entry metadata with resource table metadata and all other
+    // MDNodes as null.
+    EntryProperties EP{};
+    // FIXME: ShaderFlagsAnalysis pass needs to collect and provide ShaderFlags
+    // for each entry function. Currently, ShaderFlags value provided by
+    // ShaderFlagsAnalysis pass is created by walking *all* the function
+    // instructions of the module. Is it correct to use this value for
+    // metadata of the empty library entry?
+    ShaderEntryMDInfo EmptyFunEntryProps(EP, Ctx, MDAnalysisInfo.ShaderProfile,
+                                         MDResources, ShaderFlags);
+    MDTuple *EmptyMDT = EmptyFunEntryProps.getAsMetadata();
+    EntryFnMDNodes.emplace_back(EmptyMDT);
+
+    for (auto EntryProp : MDAnalysisInfo.EntryPropertyVec) {
+      // FIXME: ShaderFlagsAnalysis pass needs to collect and provide
+      // ShaderFlags for each entry function. For now, assume the shader flags
+      // value of each entry function compiled for a lib_* shader profile
+      // (i.e., EntryProp.Entry) is 0.
+      ShaderEntryMDInfo SEP(EntryProp, Ctx, MDAnalysisInfo.ShaderProfile,
+                            nullptr, 0);
+      MDTuple *EmptyMDT = SEP.getAsMetadata();
+      EntryFnMDNodes.emplace_back(EmptyMDT);
+    }
+  } break;
+  case Triple::EnvironmentType::Compute: {
+    size_t NumEntries = MDAnalysisInfo.EntryPropertyVec.size();
+    if (NumEntries > 0) {
+      assert(NumEntries == 1 &&
+             "Compute shader: One and only one entry expected");
+      EntryProperties EntryProp = MDAnalysisInfo.EntryPropertyVec[0];
+      // ShaderFlagsAnalysis pass needs to collect and provide ShaderFlags for
+      // each entry function. Currently, even though the ShaderFlags value
+      // provided by ShaderFlagsAnalysis pass is created by walking all the
+      // function instructions of the module, it is sufficient since there is
+      // only one entry function in the module.
+      ShaderEntryMDInfo SEP(EntryProp, Ctx, MDAnalysisInfo.ShaderProfile,
+                            MDResources, ShaderFlags);
+      MDTuple *EmptyMDT = SEP.getAsMetadata();
+      EntryFnMDNodes.emplace_back(EmptyMDT);
+    }
+    break;
+  }
+  case Triple::EnvironmentType::Amplification:
+  case Triple::EnvironmentType::Mesh:
+  case Triple::EnvironmentType::Vertex:
+  case Triple::EnvironmentType::Hull:
+  case Triple::EnvironmentType::Domain:
+  case Triple::EnvironmentType::Geometry:
+  case Triple::EnvironmentType::Pixel: {
+    size_t NumEntries = MDAnalysisInfo.EntryPropertyVec.size();
+    if (NumEntries > 0) {
+      assert(NumEntries == 1 && "non-lib profiles should only have one entry");
+      EntryProperties EntryProp = MDAnalysisInfo.EntryPropertyVec[0];
+      // ShaderFlagsAnalysis pass needs to collect and provide ShaderFlags for
+      // each entry function. Currently, even though the ShaderFlags value
+      // provided by ShaderFlagsAnalysis pass is created by walking all the
+      // function instructions of the module, it is sufficient since there is
+      // only one entry function in the module.
+      ShaderEntryMDInfo SEP(EntryProp, Ctx, MDAnalysisInfo.ShaderProfile,
+                            MDResources, ShaderFlags);
+      MDTuple *EmptyMDT = SEP.getAsMetadata();
+      EntryFnMDNodes.emplace_back(EmptyMDT);
+    }
+  } break;
+  default:
+    assert(0 && "invalid profile");
+    break;
+  }
+
+  NamedMDNode *EntryPointsNamedMD =
+      M.getOrInsertNamedMetadata("dx.entryPoints");
+  for (auto *Entry : EntryFnMDNodes)
+    EntryPointsNamedMD->addOperand(Entry);
+}
+
 static void translateMetadata(Module &M, const DXILResourceMap &DRM,
                               const dxil::Resources &MDResources,
-                              const ComputedShaderFlags &ShaderFlags) {
-  dxil::ValidatorVersionMD ValVerMD(M);
-  if (ValVerMD.isEmpty())
-    ValVerMD.update(VersionTuple(1, 0));
-  dxil::createShaderModelMD(M);
-  dxil::createDXILVersionMD(M);
+                              const ComputedShaderFlags &ShaderFlags,
+                              const dxil::ModuleMetadataInfo &MDAnalysisInfo) {
+  LLVMContext &Ctx = M.getContext();
+  IRBuilder<> IRB(Ctx);
+  if (MDAnalysisInfo.ValidatorVersion.empty()) {
+    // Module has no metadata node signifying valid validator version.
+    // Create metadata dx.valver node with version value of 1.0
+    const VersionTuple DefaultValidatorVer{1, 0};
+    Metadata *MDVals[2];
+    MDVals[0] =
+        ConstantAsMetadata::get(IRB.getInt32(DefaultValidatorVer.getMajor()));
+    MDVals[1] = ConstantAsMetadata::get(
+        IRB.getInt32(DefaultValidatorVer.getMinor().value_or(0)));
+    NamedMDNode *ValVerNode = M.getOrInsertNamedMetadata("dx.valver");
+    ValVerNode->addOperand(MDNode::get(Ctx, MDVals));
+  }
+
+  Metadata *SMVals[3];
+  VersionTuple SM = MDAnalysisInfo.ShaderModelVersion;
+  SMVals[0] =
+      MDString::get(Ctx, getShortShaderStage(MDAnalysisInfo.ShaderProfile));
+  SMVals[1] = ConstantAsMetadata::get(IRB.getInt32(SM.getMajor()));
+  SMVals[2] = ConstantAsMetadata::get(IRB.getInt32(SM.getMinor().value_or(0)));
+  NamedMDNode *SMMDNode = M.getOrInsertNamedMetadata("dx.shaderModel");
+  SMMDNode->addOperand(MDNode::get(Ctx, SMVals));
+
+  VersionTuple DXILVer = MDAnalysisInfo.DXILVersion;
+  Metadata *DXILVals[2];
+  DXILVals[0] = ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMajor()));
+  DXILVals[1] =
+      ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMinor().value_or(0)));
+  NamedMDNode *DXILVerMDNode = M.getOrInsertNamedMetadata("dx.version");
+  DXILVerMDNode->addOperand(MDNode::get(Ctx, DXILVals));
 
   emitResourceMetadata(M, DRM, MDResources);
 
-  dxil::createEntryMD(M, static_cast<uint64_t>(ShaderFlags));
+  createEntryMD(M, static_cast<uint64_t>(ShaderFlags), MDAnalysisInfo);
 }
 
 PreservedAnalyses DXILTranslateMetadata::run(Module &M,
@@ -85,8 +344,10 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
   const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M);
   const ComputedShaderFlags &ShaderFlags =
       MAM.getResult<ShaderFlagsAnalysis>(M);
+  const dxil::ModuleMetadataInfo MetadataInfo =
+      MAM.getResult<DXILMetadataAnalysis>(M);
 
-  translateMetadata(M, DRM, MDResources, ShaderFlags);
+  translateMetadata(M, DRM, MDResources, ShaderFlags, MetadataInfo);
 
   return PreservedAnalyses::all();
 }
@@ -114,8 +375,10 @@ class DXILTranslateMetadataLegacy : public ModulePass {
         getAnalysis<DXILResourceMDWrapper>().getDXILResource();
     const ComputedShaderFlags &ShaderFlags =
         getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
+    dxil::ModuleMetadataInfo MetadataInfo =
+        getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
 
-    translateMetadata(M, DRM, MDResources, ShaderFlags);
+    translateMetadata(M, DRM, MDResources, ShaderFlags, MetadataInfo);
     return true;
   }
 };
@@ -133,5 +396,6 @@ INITIALIZE_PASS_BEGIN(DXILTranslateMetadataLegacy, "dxil-translate-metadata",
 INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(DXILResourceMDWrapper)
 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)
+INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
 INITIALIZE_PASS_END(DXILTranslateMetadataLegacy, "dxil-translate-metadata",
                     "DXIL Translate Metadata", false, false)
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 606022a9835f04..1ca75661f73d15 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -53,6 +53,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   initializeDXContainerGlobalsPass(*PR);
   initializeDXILOpLoweringLegacyPass(*PR);
   initializeDXILTranslateMetadataLegacyPass(*PR);
+  initializeDXILMetadataAnalysisWrapperPassPass(*PR);
   initializeDXILResourceMDWrapperPass(*PR);
   initializeShaderFlagsAnalysisWrapperPass(*PR);
   initializeDXILFinalizeLinkageLegacyPass(*PR);
diff --git a/llvm/test/CodeGen/DirectX/legalize-module-flags.ll b/llvm/test/CodeGen/DirectX/legalize-module-flags.ll
index 1483a87e0b4bd3..6c29deabc2aa3b 100644
--- a/llvm/test/CodeGen/DirectX/legalize-module-flags.ll
+++ b/llvm/test/CodeGen/DirectX/legalize-module-flags.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-prepare < %s | FileCheck %s
+; RUN: opt -S -dxil-prepare -mtriple=dxil-unknown-shadermodel6.0-compute %s | FileCheck %s
 
 ; Make sure behavior flag > 6 is fixed.
 ; CHECK: !{i32 2, !"frame-pointer", i32 2}
diff --git a/llvm/test/CodeGen/DirectX/legalize-module-flags2.ll b/llvm/test/CodeGen/DirectX/legalize-module-flags2.ll
index e1803b4672684f..244ec8d54e131e 100644
--- a/llvm/test/CodeGen/DirectX/legalize-module-flags2.ll
+++ b/llvm/test/CodeGen/DirectX/legalize-module-flags2.ll
@@ -1,4 +1,4 @@
-; RUN: opt -S -dxil-prepare < %s | FileCheck %s
+; RUN: opt -S -dxil-prepare -mtriple=dxil-unknown-shadermodel6.0-library %s | FileCheck %s
 
 ; CHECK: define void @main()
 ; Make sure behavior flag > 6 is fixed.
diff --git a/llvm/test/CodeGen/DirectX/strip-call-attrs.ll b/llvm/test/CodeGen/DirectX/strip-call-attrs.ll
index f530e12fa7e580..e232ab24d69f34 100644
--- a/llvm/test/CodeGen/DirectX/strip-call-attrs.ll
+++ b/llvm/test/CodeGen/DirectX/strip-call-attrs.ll
@@ -1,6 +1,6 @@
 
 ; RUN: opt -S -dxil-prepare < %s | FileCheck %s
-target triple = "dxil-unknown-unknown"
+target triple = "dxil-unknown-shadermodel6.0-library"
 
 @f = internal unnamed_addr global float 0.000000e+00, align 4
 @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 65535, ptr @_GLOBAL__sub_I_static_global.hlsl, ptr null }]
diff --git a/llvm/test/CodeGen/DirectX/typed_ptr.ll b/llvm/test/CodeGen/DirectX/typed_ptr.ll
index 5453e87651dd72..355c4f13d0e4c6 100644
--- a/llvm/test/CodeGen/DirectX/typed_ptr.ll
+++ b/llvm/test/CodeGen/DirectX/typed_ptr.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 3
 ; RUN: opt -S -dxil-prepare < %s | FileCheck %s
-target triple = "dxil-unknown-unknown"
+target triple = "dxil-unknown-shadermodel6.0-compute"
 
 @gs = external addrspace(3) global [20 x [6 x float]], align 4
 

>From dd0412c1de9c8d60fe8a7db5617fcf52a4e38141 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Tue, 10 Sep 2024 11:21:16 -0400
Subject: [PATCH 2/8] Fix mis-named variable that was not edited appropriately
 after a copy-paste.

Add necessary function attributes to tests and remove
unnecessary check for number of entries for non-library
shaders in createEntryMD()
---
 .../Target/DirectX/DXILTranslateMetadata.cpp  | 54 +++++++++----------
 llvm/test/CodeGen/DirectX/CreateHandle.ll     |  4 +-
 .../DirectX/CreateHandleFromBinding.ll        |  4 +-
 3 files changed, 29 insertions(+), 33 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 3c7a9168a59257..95cc37a9190f3a 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -242,26 +242,24 @@ static void createEntryMD(Module &M, const uint64_t ShaderFlags,
       // EntryPro.Entry is 0.
       ShaderEntryMDInfo SEP(EntryProp, Ctx, MDAnalysisInfo.ShaderProfile,
                             nullptr, 0);
-      MDTuple *EmptyMDT = SEP.getAsMetadata();
-      EntryFnMDNodes.emplace_back(EmptyMDT);
+      MDTuple *MDT = SEP.getAsMetadata();
+      EntryFnMDNodes.emplace_back(MDT);
     }
   } break;
   case Triple::EnvironmentType::Compute: {
     size_t NumEntries = MDAnalysisInfo.EntryPropertyVec.size();
-    if (NumEntries > 0) {
-      assert(NumEntries == 1 &&
-             "Compute shader: One and only one entry expected");
-      EntryProperties EntryProp = MDAnalysisInfo.EntryPropertyVec[0];
-      // ShaderFlagsAnalysis pass needs to collect and provide ShaderFlags for
-      // each entry function. Currently, even though the ShaderFlags value
-      // provided by ShaderFlagsAnalysis pass is created by walking all the
-      // function instructions of the module, it is sufficient to since there is
-      // only one entry function in the module.
-      ShaderEntryMDInfo SEP(EntryProp, Ctx, MDAnalysisInfo.ShaderProfile,
-                            MDResources, ShaderFlags);
-      MDTuple *EmptyMDT = SEP.getAsMetadata();
-      EntryFnMDNodes.emplace_back(EmptyMDT);
-    }
+    assert(NumEntries == 1 &&
+           "Compute shader: One and only one entry expected");
+    EntryProperties EntryProp = MDAnalysisInfo.EntryPropertyVec[0];
+    // ShaderFlagsAnalysis pass needs to collect and provide ShaderFlags for
+    // each entry function. Currently, even though the ShaderFlags value
+    // provided by ShaderFlagsAnalysis pass is created by walking all the
+    // function instructions of the module, it is sufficient to since there is
+    // only one entry function in the module.
+    ShaderEntryMDInfo SEP(EntryProp, Ctx, MDAnalysisInfo.ShaderProfile,
+                          MDResources, ShaderFlags);
+    MDTuple *MDT = SEP.getAsMetadata();
+    EntryFnMDNodes.emplace_back(MDT);
     break;
   }
   case Triple::EnvironmentType::Amplification:
@@ -272,19 +270,17 @@ static void createEntryMD(Module &M, const uint64_t ShaderFlags,
   case Triple::EnvironmentType::Geometry:
   case Triple::EnvironmentType::Pixel: {
     size_t NumEntries = MDAnalysisInfo.EntryPropertyVec.size();
-    if (NumEntries > 0) {
-      assert(NumEntries == 1 && "non-lib profiles should only have one entry");
-      EntryProperties EntryProp = MDAnalysisInfo.EntryPropertyVec[0];
-      // ShaderFlagsAnalysis pass needs to collect and provide ShaderFlags for
-      // each entry function. Currently, even though the ShaderFlags value
-      // provided by ShaderFlagsAnalysis pass is created by walking all the
-      // function instructions of the module, it is sufficient to since there is
-      // only one entry function in the module.
-      ShaderEntryMDInfo SEP(EntryProp, Ctx, MDAnalysisInfo.ShaderProfile,
-                            MDResources, ShaderFlags);
-      MDTuple *EmptyMDT = SEP.getAsMetadata();
-      EntryFnMDNodes.emplace_back(EmptyMDT);
-    }
+    assert(NumEntries == 1 && "non-lib profiles should only have one entry");
+    EntryProperties EntryProp = MDAnalysisInfo.EntryPropertyVec[0];
+    // ShaderFlagsAnalysis pass needs to collect and provide ShaderFlags for
+    // each entry function. Currently, even though the ShaderFlags value
+    // provided by ShaderFlagsAnalysis pass is created by walking all the
+    // function instructions of the module, it is sufficient to since there is
+    // only one entry function in the module.
+    ShaderEntryMDInfo SEP(EntryProp, Ctx, MDAnalysisInfo.ShaderProfile,
+                          MDResources, ShaderFlags);
+    MDTuple *MDT = SEP.getAsMetadata();
+    EntryFnMDNodes.emplace_back(MDT);
   } break;
   default:
     assert(0 && "invalid profile");
diff --git a/llvm/test/CodeGen/DirectX/CreateHandle.ll b/llvm/test/CodeGen/DirectX/CreateHandle.ll
index 40b3b2c7122722..4653baf8a3b21a 100644
--- a/llvm/test/CodeGen/DirectX/CreateHandle.ll
+++ b/llvm/test/CodeGen/DirectX/CreateHandle.ll
@@ -14,7 +14,7 @@ target triple = "dxil-pc-shadermodel6.0-compute"
 
 declare i32 @some_val();
 
-define void @test_buffers() {
+define void @test_buffers() #0 {
   ; RWBuffer<float4> Buf : register(u5, space3)
   %typed0 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
               @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
@@ -68,4 +68,4 @@ define void @test_buffers() {
 ; CHECK-DAG: [[SRVMD]] = !{!{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}}
 ; CHECK-DAG: [[UAVMD]] = !{!{{[0-9]+}}, !{{[0-9]+}}}
 
-attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) }
+attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) "hlsl.numthreads"="1,2,1" "hlsl.shader"="compute"}
diff --git a/llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll b/llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll
index dbdd2e61df7a3b..a082e0c197723c 100644
--- a/llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll
+++ b/llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll
@@ -14,7 +14,7 @@ target triple = "dxil-pc-shadermodel6.6-compute"
 
 declare i32 @some_val();
 
-define void @test_bindings() {
+define void @test_bindings() #0 {
   ; RWBuffer<float4> Buf : register(u5, space3)
   %typed0 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
               @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
@@ -73,4 +73,4 @@ define void @test_bindings() {
 ; CHECK-DAG: [[SRVMD]] = !{!{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}}
 ; CHECK-DAG: [[UAVMD]] = !{!{{[0-9]+}}, !{{[0-9]+}}}
 
-attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) }
+attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) "hlsl.numthreads"="1,2,1" "hlsl.shader"="compute"}

>From 42b57a698e7da33808a404196a3828a426ad2b37 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Tue, 10 Sep 2024 13:24:18 -0400
Subject: [PATCH 3/8] Add a test to verify correct metadata creation for
 multiple shader entries during library profile compilation.

---
 .../CodeGen/DirectX/Metadata/lib-entries.ll   | 39 +++++++++++++++++++
 1 file changed, 39 insertions(+)
 create mode 100644 llvm/test/CodeGen/DirectX/Metadata/lib-entries.ll

diff --git a/llvm/test/CodeGen/DirectX/Metadata/lib-entries.ll b/llvm/test/CodeGen/DirectX/Metadata/lib-entries.ll
new file mode 100644
index 00000000000000..3ce15eea6d437c
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/Metadata/lib-entries.ll
@@ -0,0 +1,39 @@
+; RUN: opt -S  -S -dxil-translate-metadata %s 2>&1 | FileCheck %s
+target triple = "dxil-pc-shadermodel6.8-library"
+
+
+; CHECK: !dx.valver = !{![[DXVALVER:[0-9]+]]}
+; CHECK: !dx.shaderModel = !{![[SM:[0-9]+]]}
+; CHECK: !dx.version = !{![[DXVER:[0-9]+]]}
+; CHECK: !dx.entryPoints = !{![[LIB:[0-9]+]], ![[AS:[0-9]+]], ![[MS:[0-9]+]], ![[CS:[0-9]+]]}
+
+; CHECK: ![[DXVALVER]] = !{i32 1, i32 0}
+; CHECK: ![[SM]] = !{!"lib", i32 6, i32 8}
+; CHECK: ![[DXVER]] = !{i32 1, i32 8}
+; CHECK: ![[LIB]] = !{null, !"", null, null, null}
+; CHECK: ![[AS]] = !{ptr @entry_as, !"entry_as", null, null, ![[AS_SF:[0-9]*]]}
+; CHECK: ![[AS_SF]] =  !{i32 8, i32 14}
+; CHECK: ![[MS]] = !{ptr @entry_ms, !"entry_ms", null, null, ![[MS_SF:[0-9]*]]}
+; CHECK: ![[MS_SF]] =  !{i32 8, i32 13}
+; CHECK: ![[CS]] = !{ptr @entry_cs, !"entry_cs", null, null, ![[CS_SF:[0-9]*]]}
+; CHECK: ![[CS_SF]] =  !{i32 8, i32 5, i32 4, ![[CS_NT:[0-9]*]]}
+; CHECK: !{i32 1, i32 2, i32 1}
+
+define void @entry_as() #0 {
+entry:
+  ret void
+}
+
+define i32 @entry_ms(i32 %a) #1 {
+entry:
+  ret i32 %a
+}
+
+define float @entry_cs(float %f) #3 {
+entry:
+  ret float %f
+}
+
+attributes #0 = { noinline nounwind "hlsl.shader"="amplification" }
+attributes #1 = { noinline nounwind "hlsl.shader"="mesh" }
+attributes #3 = { noinline nounwind "hlsl.numthreads"="1,2,1" "hlsl.shader"="compute" }

>From 45da9fb2025c4a44c0c75a7989456184e9e2f3cf Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Wed, 11 Sep 2024 10:31:56 -0400
Subject: [PATCH 4/8] Address PR feedback

- Delete derived struct ShaderEntryMDInfo and move the functionality
  of its methods into static functions.
- Add class DiagnosticInfoModuleFormat derived from DiagnosticInfo for
  diagnostics reporting in TranslateMetadata pass
- Consume Resource metadata information constructed by emitResourceMetadata()
- Move generation of metadata for validator version, Shader Model
  version, DXIL Version into separate functions.
- Changes to accept input shader modules with no entry functions
- Update/emit named metadata dx.valver only if Metadata Analysis info
  contains the information; it is no longer created with a default
  value if it does not.
---
 .../Target/DirectX/DXILTranslateMetadata.cpp  | 426 +++++++++---------
 .../Target/DirectX/DirectXTargetMachine.cpp   |   2 -
 llvm/test/CodeGen/DirectX/CreateHandle.ll     |   4 +-
 .../DirectX/CreateHandleFromBinding.ll        |   4 +-
 .../CodeGen/DirectX/Metadata/lib-entries.ll   |   2 -
 .../Metadata/multiple-entries-cs-error.ll     |  23 +
 6 files changed, 243 insertions(+), 218 deletions(-)
 create mode 100644 llvm/test/CodeGen/DirectX/Metadata/multiple-entries-cs-error.ll

diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 95cc37a9190f3a..f143de602b6dda 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -15,6 +15,8 @@
 #include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/Analysis/DXILResource.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
@@ -29,8 +31,38 @@
 using namespace llvm;
 using namespace llvm::dxil;
 
-static void emitResourceMetadata(Module &M, const DXILResourceMap &DRM,
-                                 const dxil::Resources &MDResources) {
+/// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
+class DiagnosticInfoModuleFormat : public DiagnosticInfo {
+private:
+  Twine Msg;
+  const Module &Mod;
+
+public:
+  /// \p M is the module for which the diagnostic is being emitted. \p Msg is
+  /// the message to show. Note that this class does not copy this message, so
+  /// this reference must be valid for the whole life time of the diagnostic.
+  DiagnosticInfoModuleFormat(const Module &M, const Twine &Msg,
+                             DiagnosticSeverity Severity = DS_Error)
+      : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
+
+  static bool classof(const DiagnosticInfo *DI) {
+    return DI->getKind() == DK_Unsupported;
+  }
+
+  const Twine &getMessage() const { return Msg; }
+
+  void print(DiagnosticPrinter &DP) const override {
+    std::string Str;
+    raw_string_ostream OS(Str);
+
+    OS << Mod.getName() << ": " << Msg << '\n';
+    OS.flush();
+    DP << Str;
+  }
+};
+
+static NamedMDNode *emitResourceMetadata(Module &M, const DXILResourceMap &DRM,
+                                         const dxil::Resources &MDResources) {
   LLVMContext &Context = M.getContext();
 
   SmallVector<Metadata *> SRVs, UAVs, CBufs, Smps;
@@ -62,11 +94,13 @@ static void emitResourceMetadata(Module &M, const DXILResourceMap &DRM,
   }
 
   if (!HasResources)
-    return;
+    return nullptr;
 
   NamedMDNode *ResourceMD = M.getOrInsertNamedMetadata("dx.resources");
   ResourceMD->addOperand(
       MDNode::get(M.getContext(), {SRVMD, UAVMD, CBufMD, SmpMD}));
+
+  return ResourceMD;
 }
 
 static StringRef getShortShaderStage(Triple::EnvironmentType Env) {
@@ -93,245 +127,218 @@ static StringRef getShortShaderStage(Triple::EnvironmentType Env) {
     break;
   }
   llvm_unreachable("Unsupported environment for DXIL generation.");
-  return "";
 }
 
 static uint32_t getShaderStage(Triple::EnvironmentType Env) {
   return (uint32_t)Env - (uint32_t)llvm::Triple::Pixel;
 }
 
-struct ShaderEntryMDInfo : EntryProperties {
-
-  enum EntryPropsTag {
-    ShaderFlagsTag = 0,
-    GSStateTag,
-    DSStateTag,
-    HSStateTag,
-    NumThreadsTag,
-    AutoBindingSpaceTag,
-    RayPayloadSizeTag,
-    RayAttribSizeTag,
-    ShaderKindTag,
-    MSStateTag,
-    ASStateTag,
-    WaveSizeTag,
-    EntryRootSigTag,
-  };
-
-  ShaderEntryMDInfo(EntryProperties &EP, LLVMContext &C,
-                    Triple::EnvironmentType SP, MDTuple *MDR = nullptr,
-                    uint64_t ShaderFlags = 0)
-      : EntryProperties(EP), Ctx(C), EntryShaderFlags(ShaderFlags),
-        MDResources(MDR), ShaderProfile(SP) {};
-
-  MDTuple *getAsMetadata() {
-    MDTuple *Properties = constructEntryPropMetadata();
-    // FIXME: Add support to construct Signatures
-    // See https://github.com/llvm/llvm-project/issues/57928
-    MDTuple *Signatures = nullptr;
-    return constructEntryMetadata(Signatures, MDResources, Properties);
-  }
+namespace {
+enum EntryPropsTag {
+  ShaderFlagsTag = 0,
+  GSStateTag,
+  DSStateTag,
+  HSStateTag,
+  NumThreadsTag,
+  AutoBindingSpaceTag,
+  RayPayloadSizeTag,
+  RayAttribSizeTag,
+  ShaderKindTag,
+  MSStateTag,
+  ASStateTag,
+  WaveSizeTag,
+  EntryRootSigTag,
+};
+} // namespace
 
-private:
-  LLVMContext &Ctx;
-  // Shader Flags for the Entry - from ShadeFLagsAnalysis pass
-  uint64_t EntryShaderFlags{0};
-  MDTuple *MDResources{nullptr};
-  Triple::EnvironmentType ShaderProfile{
-      Triple::EnvironmentType::UnknownEnvironment};
-  // Each entry point metadata record specifies:
-  //  * reference to the entry point function global symbol
-  //  * unmangled name
-  //  * list of signatures
-  //  * list of resources
-  //  * list of tag-value pairs of shader capabilities and other properties
-
-  MDTuple *constructEntryMetadata(MDTuple *Signatures, MDTuple *Resources,
-                                  MDTuple *Properties) {
-    Metadata *MDVals[5];
-    MDVals[0] =
-        Entry ? ValueAsMetadata::get(const_cast<Function *>(Entry)) : nullptr;
-    MDVals[1] = MDString::get(Ctx, Entry ? Entry->getName() : "");
-    MDVals[2] = Signatures;
-    MDVals[3] = Resources;
-    MDVals[4] = Properties;
-    return MDNode::get(Ctx, MDVals);
+static SmallVector<Metadata *>
+getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
+  SmallVector<Metadata *> MDVals;
+  MDVals.emplace_back(
+      ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Ctx), Tag)));
+  switch (Tag) {
+  case ShaderFlagsTag:
+    MDVals.emplace_back(ConstantAsMetadata::get(
+        ConstantInt::get(Type::getInt64Ty(Ctx), Value)));
+    break;
+  case ShaderKindTag:
+    MDVals.emplace_back(ConstantAsMetadata::get(
+        ConstantInt::get(Type::getInt32Ty(Ctx), Value)));
+    break;
+  default:
+    assert(false && "NYI: Unhandled entry property tag");
   }
+  return MDVals;
+}
 
-  SmallVector<Metadata *> getTagValueAsMetadata(EntryPropsTag Tag,
-                                                uint64_t Value) {
-    SmallVector<Metadata *> MDVals;
-    MDVals.emplace_back(
-        ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Ctx), Tag)));
-    switch (Tag) {
-    case ShaderFlagsTag:
+static MDTuple *
+getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
+                       const Triple::EnvironmentType ShaderProfile) {
+  SmallVector<Metadata *> MDVals;
+  LLVMContext &Ctx = EP.Entry->getContext();
+  if (EntryShaderFlags != 0)
+    MDVals.append(getTagValueAsMetadata(ShaderFlagsTag, EntryShaderFlags, Ctx));
+
+  if (EP.Entry != nullptr) {
+    // FIXME: support more props.
+    // See https://github.com/llvm/llvm-project/issues/57948.
+    // Add shader kind for lib entries.
+    if (ShaderProfile == Triple::EnvironmentType::Library &&
+        EP.ShaderStage != Triple::EnvironmentType::Library)
+      MDVals.append(getTagValueAsMetadata(ShaderKindTag,
+                                          getShaderStage(EP.ShaderStage), Ctx));
+
+    if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
       MDVals.emplace_back(ConstantAsMetadata::get(
-          ConstantInt::get(Type::getInt64Ty(Ctx), Value)));
-      break;
-    case ShaderKindTag:
-      MDVals.emplace_back(ConstantAsMetadata::get(
-          ConstantInt::get(Type::getInt32Ty(Ctx), Value)));
-      break;
-    default:
-      assert(false && "NYI: Unhandled entry property tag");
+          ConstantInt::get(Type::getInt32Ty(Ctx), NumThreadsTag)));
+      std::vector<Metadata *> NumThreadVals;
+      NumThreadVals.emplace_back(ConstantAsMetadata::get(
+          ConstantInt::get(Type::getInt32Ty(Ctx), EP.NumThreadsX)));
+      NumThreadVals.emplace_back(ConstantAsMetadata::get(
+          ConstantInt::get(Type::getInt32Ty(Ctx), EP.NumThreadsY)));
+      NumThreadVals.emplace_back(ConstantAsMetadata::get(
+          ConstantInt::get(Type::getInt32Ty(Ctx), EP.NumThreadsZ)));
+      MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
     }
-    return MDVals;
   }
+  if (MDVals.empty())
+    return nullptr;
+  return MDNode::get(Ctx, MDVals);
+}
 
-  MDTuple *constructEntryPropMetadata() {
-    SmallVector<Metadata *> MDVals;
-    if (EntryShaderFlags != 0)
-      MDVals.append(getTagValueAsMetadata(ShaderFlagsTag, EntryShaderFlags));
-
-    if (Entry != nullptr) {
-      // FIXME: support more props.
-      // See https://github.com/llvm/llvm-project/issues/57948.
-      // Add shader kind for lib entries.
-      if (ShaderProfile == Triple::EnvironmentType::Library &&
-          ShaderStage != Triple::EnvironmentType::Library)
-        MDVals.append(
-            getTagValueAsMetadata(ShaderKindTag, getShaderStage(ShaderStage)));
-
-      if (ShaderStage == Triple::EnvironmentType::Compute) {
-        MDVals.emplace_back(ConstantAsMetadata::get(
-            ConstantInt::get(Type::getInt32Ty(Ctx), NumThreadsTag)));
-        std::vector<Metadata *> NumThreadVals;
-        NumThreadVals.emplace_back(ConstantAsMetadata::get(
-            ConstantInt::get(Type::getInt32Ty(Ctx), NumThreadsX)));
-        NumThreadVals.emplace_back(ConstantAsMetadata::get(
-            ConstantInt::get(Type::getInt32Ty(Ctx), NumThreadsY)));
-        NumThreadVals.emplace_back(ConstantAsMetadata::get(
-            ConstantInt::get(Type::getInt32Ty(Ctx), NumThreadsZ)));
-        MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
-      }
-    }
-    if (MDVals.empty())
-      return nullptr;
-    return MDNode::get(Ctx, MDVals);
-  }
-};
-
-static void createEntryMD(Module &M, const uint64_t ShaderFlags,
-                          const dxil::ModuleMetadataInfo &MDAnalysisInfo) {
-  auto &Ctx = M.getContext();
-  // FIXME: generate metadata for resource.
-  MDTuple *MDResources = nullptr;
-  if (auto *NamedResources = M.getNamedMetadata("dx.resources"))
-    MDResources = dyn_cast<MDTuple>(NamedResources->getOperand(0));
-
-  std::vector<MDNode *> EntryFnMDNodes;
-  switch (MDAnalysisInfo.ShaderProfile) {
-  case Triple::EnvironmentType::Library: {
-    // Library has an entry metadata with resource table metadata and all other
-    // MDNodes as null.
-    EntryProperties EP{};
-    // FIXME: ShaderFlagsAnalysis pass needs to collect and provide ShaderFlags
-    // for each entry function. Currently, ShaderFlags value provided by
-    // ShaderFlagsAnalysis pass is created by walking *all* the function
-    // instructions of the module. Is it is correct to use this value for
-    // metadata of the empty library entry?
-    ShaderEntryMDInfo EmptyFunEntryProps(EP, Ctx, MDAnalysisInfo.ShaderProfile,
-                                         MDResources, ShaderFlags);
-    MDTuple *EmptyMDT = EmptyFunEntryProps.getAsMetadata();
-    EntryFnMDNodes.emplace_back(EmptyMDT);
-
-    for (auto EntryProp : MDAnalysisInfo.EntryPropertyVec) {
-      // FIXME: ShaderFlagsAnalysis pass needs to collect and provide
-      // ShaderFlags for each entry function. For now, assume shader flags value
-      // of entry functions being compiled for lib_* shader profile viz.,
-      // EntryPro.Entry is 0.
-      ShaderEntryMDInfo SEP(EntryProp, Ctx, MDAnalysisInfo.ShaderProfile,
-                            nullptr, 0);
-      MDTuple *MDT = SEP.getAsMetadata();
-      EntryFnMDNodes.emplace_back(MDT);
-    }
-  } break;
-  case Triple::EnvironmentType::Compute: {
-    size_t NumEntries = MDAnalysisInfo.EntryPropertyVec.size();
-    assert(NumEntries == 1 &&
-           "Compute shader: One and only one entry expected");
-    EntryProperties EntryProp = MDAnalysisInfo.EntryPropertyVec[0];
-    // ShaderFlagsAnalysis pass needs to collect and provide ShaderFlags for
-    // each entry function. Currently, even though the ShaderFlags value
-    // provided by ShaderFlagsAnalysis pass is created by walking all the
-    // function instructions of the module, it is sufficient to since there is
-    // only one entry function in the module.
-    ShaderEntryMDInfo SEP(EntryProp, Ctx, MDAnalysisInfo.ShaderProfile,
-                          MDResources, ShaderFlags);
-    MDTuple *MDT = SEP.getAsMetadata();
-    EntryFnMDNodes.emplace_back(MDT);
-    break;
-  }
-  case Triple::EnvironmentType::Amplification:
-  case Triple::EnvironmentType::Mesh:
-  case Triple::EnvironmentType::Vertex:
-  case Triple::EnvironmentType::Hull:
-  case Triple::EnvironmentType::Domain:
-  case Triple::EnvironmentType::Geometry:
-  case Triple::EnvironmentType::Pixel: {
-    size_t NumEntries = MDAnalysisInfo.EntryPropertyVec.size();
-    assert(NumEntries == 1 && "non-lib profiles should only have one entry");
-    EntryProperties EntryProp = MDAnalysisInfo.EntryPropertyVec[0];
-    // ShaderFlagsAnalysis pass needs to collect and provide ShaderFlags for
-    // each entry function. Currently, even though the ShaderFlags value
-    // provided by ShaderFlagsAnalysis pass is created by walking all the
-    // function instructions of the module, it is sufficient to since there is
-    // only one entry function in the module.
-    ShaderEntryMDInfo SEP(EntryProp, Ctx, MDAnalysisInfo.ShaderProfile,
-                          MDResources, ShaderFlags);
-    MDTuple *MDT = SEP.getAsMetadata();
-    EntryFnMDNodes.emplace_back(MDT);
-  } break;
-  default:
-    assert(0 && "invalid profile");
-    break;
-  }
+// Each entry point metadata record specifies:
+//  * reference to the entry point function global symbol
+//  * unmangled name
+//  * list of signatures
+//  * list of resources
+//  * list of tag-value pairs of shader capabilities and other properties
+
+MDTuple *constructEntryMetadata(const Function *EntryFn, MDTuple *Signatures,
+                                MDNode *Resources, MDTuple *Properties,
+                                LLVMContext &Ctx) {
+  Metadata *MDVals[5];
+  MDVals[0] =
+      EntryFn ? ValueAsMetadata::get(const_cast<Function *>(EntryFn)) : nullptr;
+  MDVals[1] = MDString::get(Ctx, EntryFn ? EntryFn->getName() : "");
+  MDVals[2] = Signatures;
+  MDVals[3] = Resources;
+  MDVals[4] = Properties;
+  return MDNode::get(Ctx, MDVals);
+}
 
-  NamedMDNode *EntryPointsNamedMD =
-      M.getOrInsertNamedMetadata("dx.entryPoints");
-  for (auto *Entry : EntryFnMDNodes)
-    EntryPointsNamedMD->addOperand(Entry);
+static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
+                            MDNode *MDResources,
+                            const uint64_t EntryShaderFlags,
+                            const Triple::EnvironmentType ShaderProfile) {
+  MDTuple *Properties =
+      getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
+  return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties,
+                                EP.Entry->getContext());
 }
 
-static void translateMetadata(Module &M, const DXILResourceMap &DRM,
-                              const dxil::Resources &MDResources,
-                              const ComputedShaderFlags &ShaderFlags,
-                              const dxil::ModuleMetadataInfo &MDAnalysisInfo) {
-  LLVMContext &Ctx = M.getContext();
-  IRBuilder<> IRB(Ctx);
-  if (MDAnalysisInfo.ValidatorVersion.empty()) {
-    // Module has no metadata node signifying valid validator version.
-    // Create metadata dx.valver node with version value of 1.0
-    const VersionTuple DefaultValidatorVer{1, 0};
+static void emitValidatorVersionMD(Module &M, const ModuleMetadataInfo &MMDI) {
+  if (!MMDI.ValidatorVersion.empty()) {
+    LLVMContext &Ctx = M.getContext();
+    IRBuilder<> IRB(Ctx);
     Metadata *MDVals[2];
     MDVals[0] =
-        ConstantAsMetadata::get(IRB.getInt32(DefaultValidatorVer.getMajor()));
+        ConstantAsMetadata::get(IRB.getInt32(MMDI.ValidatorVersion.getMajor()));
     MDVals[1] = ConstantAsMetadata::get(
-        IRB.getInt32(DefaultValidatorVer.getMinor().value_or(0)));
+        IRB.getInt32(MMDI.ValidatorVersion.getMinor().value_or(0)));
     NamedMDNode *ValVerNode = M.getOrInsertNamedMetadata("dx.valver");
+    // Set validator version obtained from DXIL Metadata Analysis pass
+    ValVerNode->clearOperands();
     ValVerNode->addOperand(MDNode::get(Ctx, MDVals));
   }
+}
 
+static void emitShaderModelVersionMD(Module &M,
+                                     const ModuleMetadataInfo &MMDI) {
+  LLVMContext &Ctx = M.getContext();
+  IRBuilder<> IRB(Ctx);
   Metadata *SMVals[3];
-  VersionTuple SM = MDAnalysisInfo.ShaderModelVersion;
-  SMVals[0] =
-      MDString::get(Ctx, getShortShaderStage(MDAnalysisInfo.ShaderProfile));
+  VersionTuple SM = MMDI.ShaderModelVersion;
+  SMVals[0] = MDString::get(Ctx, getShortShaderStage(MMDI.ShaderProfile));
   SMVals[1] = ConstantAsMetadata::get(IRB.getInt32(SM.getMajor()));
   SMVals[2] = ConstantAsMetadata::get(IRB.getInt32(SM.getMinor().value_or(0)));
   NamedMDNode *SMMDNode = M.getOrInsertNamedMetadata("dx.shaderModel");
   SMMDNode->addOperand(MDNode::get(Ctx, SMVals));
+}
 
-  VersionTuple DXILVer = MDAnalysisInfo.DXILVersion;
+static void emitDXILVersionTupleMD(Module &M, const ModuleMetadataInfo &MMDI) {
+  LLVMContext &Ctx = M.getContext();
+  IRBuilder<> IRB(Ctx);
+  VersionTuple DXILVer = MMDI.DXILVersion;
   Metadata *DXILVals[2];
   DXILVals[0] = ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMajor()));
   DXILVals[1] =
       ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMinor().value_or(0)));
   NamedMDNode *DXILVerMDNode = M.getOrInsertNamedMetadata("dx.version");
   DXILVerMDNode->addOperand(MDNode::get(Ctx, DXILVals));
+}
 
-  emitResourceMetadata(M, DRM, MDResources);
+static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
+                                        uint64_t ShaderFlags) {
+  LLVMContext &Ctx = M.getContext();
+  MDTuple *Properties = nullptr;
+  if (ShaderFlags != 0) {
+    SmallVector<Metadata *> MDVals;
+    // FIXME: ShaderFlagsAnalysis pass needs to collect and provide
+    // ShaderFlags for each entry function. Currently, ShaderFlags value
+    // provided by ShaderFlagsAnalysis pass is created by walking *all* the
+    // function instructions of the module. Is it is correct to use this value
+    // for metadata of the empty library entry?
+    MDVals.append(getTagValueAsMetadata(ShaderFlagsTag, ShaderFlags, Ctx));
+    Properties = MDNode::get(Ctx, MDVals);
+  }
+  // Library has an entry metadata with resource table metadata and all other
+  // MDNodes as null.
+  return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx);
+}
+
+static void translateMetadata(Module &M, const DXILResourceMap &DRM,
+                              const Resources &MDResources,
+                              const ComputedShaderFlags &ShaderFlags,
+                              const ModuleMetadataInfo &MMDI) {
+  LLVMContext &Ctx = M.getContext();
+  IRBuilder<> IRB(Ctx);
+  SmallVector<MDNode *> EntryFnMDNodes;
+
+  emitValidatorVersionMD(M, MMDI);
+  emitShaderModelVersionMD(M, MMDI);
+  emitDXILVersionTupleMD(M, MMDI);
+  NamedMDNode *NamedResourceMD = emitResourceMetadata(M, DRM, MDResources);
+  auto *ResourceMD =
+      (NamedResourceMD != nullptr) ? NamedResourceMD->getOperand(0) : nullptr;
+  // FIXME: Add support to construct Signatures
+  // See https://github.com/llvm/llvm-project/issues/57928
+  MDTuple *Signatures = nullptr;
+
+  if (MMDI.ShaderProfile == Triple::EnvironmentType::Library)
+    EntryFnMDNodes.emplace_back(
+        emitTopLevelLibraryNode(M, ResourceMD, ShaderFlags));
+  else if (MMDI.EntryPropertyVec.size() > 1) {
+    M.getContext().diagnose(DiagnosticInfoModuleFormat(
+        M, "Non-library shader: One and only one entry expected"));
+  }
+
+  for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
+    // FIXME: ShaderFlagsAnalysis pass needs to collect and provide
+    // ShaderFlags for each entry function. For now, assume shader flags value
+    // of entry functions being compiled for lib_* shader profile viz.,
+    // EntryPro.Entry is 0.
+    uint64_t EntryShaderFlags =
+        (MMDI.ShaderProfile == Triple::EnvironmentType::Library) ? 0
+                                                                 : ShaderFlags;
+    EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
+                                            EntryShaderFlags,
+                                            MMDI.ShaderProfile));
+  }
 
-  createEntryMD(M, static_cast<uint64_t>(ShaderFlags), MDAnalysisInfo);
+  NamedMDNode *EntryPointsNamedMD =
+      M.getOrInsertNamedMetadata("dx.entryPoints");
+  for (auto *Entry : EntryFnMDNodes)
+    EntryPointsNamedMD->addOperand(Entry);
 }
 
 PreservedAnalyses DXILTranslateMetadata::run(Module &M,
@@ -340,10 +347,9 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
   const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M);
   const ComputedShaderFlags &ShaderFlags =
       MAM.getResult<ShaderFlagsAnalysis>(M);
-  const dxil::ModuleMetadataInfo MetadataInfo =
-      MAM.getResult<DXILMetadataAnalysis>(M);
+  const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
 
-  translateMetadata(M, DRM, MDResources, ShaderFlags, MetadataInfo);
+  translateMetadata(M, DRM, MDResources, ShaderFlags, MMDI);
 
   return PreservedAnalyses::all();
 }
@@ -371,10 +377,10 @@ class DXILTranslateMetadataLegacy : public ModulePass {
         getAnalysis<DXILResourceMDWrapper>().getDXILResource();
     const ComputedShaderFlags &ShaderFlags =
         getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
-    dxil::ModuleMetadataInfo MetadataInfo =
+    dxil::ModuleMetadataInfo MMDI =
         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
 
-    translateMetadata(M, DRM, MDResources, ShaderFlags, MetadataInfo);
+    translateMetadata(M, DRM, MDResources, ShaderFlags, MMDI);
     return true;
   }
 };
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 1ca75661f73d15..53f5bc60b5cfec 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -28,7 +28,6 @@
 #include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/IR/IRPrintingPasses.h"
 #include "llvm/IR/LegacyPassManager.h"
-#include "llvm/InitializePasses.h"
 #include "llvm/MC/MCSectionDXContainer.h"
 #include "llvm/MC/SectionKind.h"
 #include "llvm/MC/TargetRegistry.h"
@@ -53,7 +52,6 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   initializeDXContainerGlobalsPass(*PR);
   initializeDXILOpLoweringLegacyPass(*PR);
   initializeDXILTranslateMetadataLegacyPass(*PR);
-  initializeDXILMetadataAnalysisWrapperPassPass(*PR);
   initializeDXILResourceMDWrapperPass(*PR);
   initializeShaderFlagsAnalysisWrapperPass(*PR);
   initializeDXILFinalizeLinkageLegacyPass(*PR);
diff --git a/llvm/test/CodeGen/DirectX/CreateHandle.ll b/llvm/test/CodeGen/DirectX/CreateHandle.ll
index 4653baf8a3b21a..40b3b2c7122722 100644
--- a/llvm/test/CodeGen/DirectX/CreateHandle.ll
+++ b/llvm/test/CodeGen/DirectX/CreateHandle.ll
@@ -14,7 +14,7 @@ target triple = "dxil-pc-shadermodel6.0-compute"
 
 declare i32 @some_val();
 
-define void @test_buffers() #0 {
+define void @test_buffers() {
   ; RWBuffer<float4> Buf : register(u5, space3)
   %typed0 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
               @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
@@ -68,4 +68,4 @@ define void @test_buffers() #0 {
 ; CHECK-DAG: [[SRVMD]] = !{!{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}}
 ; CHECK-DAG: [[UAVMD]] = !{!{{[0-9]+}}, !{{[0-9]+}}}
 
-attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) "hlsl.numthreads"="1,2,1" "hlsl.shader"="compute"}
+attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) }
diff --git a/llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll b/llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll
index a082e0c197723c..dbdd2e61df7a3b 100644
--- a/llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll
+++ b/llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll
@@ -14,7 +14,7 @@ target triple = "dxil-pc-shadermodel6.6-compute"
 
 declare i32 @some_val();
 
-define void @test_bindings() #0 {
+define void @test_bindings() {
   ; RWBuffer<float4> Buf : register(u5, space3)
   %typed0 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
               @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
@@ -73,4 +73,4 @@ define void @test_bindings() #0 {
 ; CHECK-DAG: [[SRVMD]] = !{!{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}, !{{[0-9]+}}}
 ; CHECK-DAG: [[UAVMD]] = !{!{{[0-9]+}}, !{{[0-9]+}}}
 
-attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) "hlsl.numthreads"="1,2,1" "hlsl.shader"="compute"}
+attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) }
diff --git a/llvm/test/CodeGen/DirectX/Metadata/lib-entries.ll b/llvm/test/CodeGen/DirectX/Metadata/lib-entries.ll
index 3ce15eea6d437c..e2f2a482b8ca0d 100644
--- a/llvm/test/CodeGen/DirectX/Metadata/lib-entries.ll
+++ b/llvm/test/CodeGen/DirectX/Metadata/lib-entries.ll
@@ -2,12 +2,10 @@
 target triple = "dxil-pc-shadermodel6.8-library"
 
 
-; CHECK: !dx.valver = !{![[DXVALVER:[0-9]+]]}
 ; CHECK: !dx.shaderModel = !{![[SM:[0-9]+]]}
 ; CHECK: !dx.version = !{![[DXVER:[0-9]+]]}
 ; CHECK: !dx.entryPoints = !{![[LIB:[0-9]+]], ![[AS:[0-9]+]], ![[MS:[0-9]+]], ![[CS:[0-9]+]]}
 
-; CHECK: ![[DXVALVER]] = !{i32 1, i32 0}
 ; CHECK: ![[SM]] = !{!"lib", i32 6, i32 8}
 ; CHECK: ![[DXVER]] = !{i32 1, i32 8}
 ; CHECK: ![[LIB]] = !{null, !"", null, null, null}
diff --git a/llvm/test/CodeGen/DirectX/Metadata/multiple-entries-cs-error.ll b/llvm/test/CodeGen/DirectX/Metadata/multiple-entries-cs-error.ll
new file mode 100644
index 00000000000000..9697d4389a888a
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/Metadata/multiple-entries-cs-error.ll
@@ -0,0 +1,23 @@
+; RUN: not opt -S  -S -dxil-translate-metadata %s 2>&1 | FileCheck %s
+target triple = "dxil-pc-shadermodel6.8-compute"
+
+; CHECK: Non-library shader: One and only one entry expected
+
+define void @entry_as() #0 {
+entry:
+  ret void
+}
+
+define i32 @entry_ms(i32 %a) #1 {
+entry:
+  ret i32 %a
+}
+
+define float @entry_cs(float %f) #3 {
+entry:
+  ret float %f
+}
+
+attributes #0 = { noinline nounwind "hlsl.shader"="amplification" }
+attributes #1 = { noinline nounwind "hlsl.shader"="mesh" }
+attributes #3 = { noinline nounwind "hlsl.numthreads"="1,2,1" "hlsl.shader"="compute" }

>From c9ee494484b8d2379d4719db2bf56611720763e2 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 13 Sep 2024 10:20:23 -0400
Subject: [PATCH 5/8] Add a check to ensure entry shader stage is the same as
 target profile for non-library target profile compilation.

Add a test to verify the check triggers diagnostic report appropriately.
---
 llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp    |  7 +++++++
 .../CodeGen/DirectX/Metadata/target-profile-error.ll | 12 ++++++++++++
 2 files changed, 19 insertions(+)
 create mode 100644 llvm/test/CodeGen/DirectX/Metadata/target-profile-error.ll

diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index f143de602b6dda..fd8b2791e371e4 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -330,6 +330,13 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
     uint64_t EntryShaderFlags =
         (MMDI.ShaderProfile == Triple::EnvironmentType::Library) ? 0
                                                                  : ShaderFlags;
+    if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
+      if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
+        M.getContext().diagnose(DiagnosticInfoModuleFormat(
+            M, "Non-library shader: Stage of Shader entry different from "
+               "target profile"));
+      }
+    }
     EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
                                             EntryShaderFlags,
                                             MMDI.ShaderProfile));
diff --git a/llvm/test/CodeGen/DirectX/Metadata/target-profile-error.ll b/llvm/test/CodeGen/DirectX/Metadata/target-profile-error.ll
new file mode 100644
index 00000000000000..22a47c356fa70a
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/Metadata/target-profile-error.ll
@@ -0,0 +1,12 @@
+; RUN: not opt -S -dxil-translate-metadata %s 2>&1 | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-pixel"
+
+; CHECK: Non-library shader: Stage of Shader entry different from target profile
+
+define void @entry() #0 {
+entry:
+  ret void
+}
+
+attributes #0 = { noinline nounwind "exp-shader"="cs" "hlsl.numthreads"="1,2,1" "hlsl.shader"="compute" }

>From 166a9dd78d37260b6675bb1649c969238431101d Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Fri, 13 Sep 2024 15:49:33 -0400
Subject: [PATCH 6/8] Changes upon rebasing

---
 llvm/lib/Target/DirectX/DXILPrepare.cpp          | 3 +++
 llvm/lib/Target/DirectX/DirectXTargetMachine.cpp | 1 +
 2 files changed, 4 insertions(+)

diff --git a/llvm/lib/Target/DirectX/DXILPrepare.cpp b/llvm/lib/Target/DirectX/DXILPrepare.cpp
index 1a766c5fb7b4a8..6cedd308ca24f2 100644
--- a/llvm/lib/Target/DirectX/DXILPrepare.cpp
+++ b/llvm/lib/Target/DirectX/DXILPrepare.cpp
@@ -248,7 +248,10 @@ class DXILPrepareModule : public ModulePass {
   DXILPrepareModule() : ModulePass(ID) {}
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesAll();
+    AU.addPreserved<ShaderFlagsAnalysisWrapper>();
+    AU.addPreserved<DXILResourceMDWrapper>();
     AU.addRequired<DXILMetadataAnalysisWrapperPass>();
+    AU.addPreserved<DXILResourceWrapperPass>();
   }
   static char ID; // Pass identification.
 };
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 53f5bc60b5cfec..606022a9835f04 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -28,6 +28,7 @@
 #include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/IR/IRPrintingPasses.h"
 #include "llvm/IR/LegacyPassManager.h"
+#include "llvm/InitializePasses.h"
 #include "llvm/MC/MCSectionDXContainer.h"
 #include "llvm/MC/SectionKind.h"
 #include "llvm/MC/TargetRegistry.h"

>From cb8a3951f179e2444783aa4047e48168d6005813 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Tue, 17 Sep 2024 16:17:50 -0400
Subject: [PATCH 7/8] Address PR feedback

---
 .../llvm/Analysis/DXILMetadataAnalysis.h      |  4 +--
 llvm/lib/Target/DirectX/DXILPrepare.cpp       |  3 --
 .../Target/DirectX/DXILTranslateMetadata.cpp  | 28 ++++++++-----------
 .../DirectX/Metadata/target-profile-error.ll  |  2 +-
 4 files changed, 14 insertions(+), 23 deletions(-)

diff --git a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
index cb442669a24dfe..cb535ac14f1c61 100644
--- a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
+++ b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
@@ -23,7 +23,7 @@ namespace dxil {
 struct EntryProperties {
   const Function *Entry{nullptr};
   // Specific target shader stage may be specified for entry functions
-  Triple::EnvironmentType ShaderStage = Triple::UnknownEnvironment;
+  Triple::EnvironmentType ShaderStage{Triple::UnknownEnvironment};
   unsigned NumThreadsX{0}; // X component
   unsigned NumThreadsY{0}; // Y component
   unsigned NumThreadsZ{0}; // Z component
@@ -34,7 +34,7 @@ struct EntryProperties {
 struct ModuleMetadataInfo {
   VersionTuple DXILVersion{};
   VersionTuple ShaderModelVersion{};
-  Triple::EnvironmentType ShaderProfile = Triple::UnknownEnvironment;
+  Triple::EnvironmentType ShaderProfile{Triple::UnknownEnvironment};
   VersionTuple ValidatorVersion{};
   SmallVector<EntryProperties> EntryPropertyVec{};
   void print(raw_ostream &OS) const;
diff --git a/llvm/lib/Target/DirectX/DXILPrepare.cpp b/llvm/lib/Target/DirectX/DXILPrepare.cpp
index 6cedd308ca24f2..1a766c5fb7b4a8 100644
--- a/llvm/lib/Target/DirectX/DXILPrepare.cpp
+++ b/llvm/lib/Target/DirectX/DXILPrepare.cpp
@@ -248,10 +248,7 @@ class DXILPrepareModule : public ModulePass {
   DXILPrepareModule() : ModulePass(ID) {}
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesAll();
-    AU.addPreserved<ShaderFlagsAnalysisWrapper>();
-    AU.addPreserved<DXILResourceMDWrapper>();
     AU.addRequired<DXILMetadataAnalysisWrapperPass>();
-    AU.addPreserved<DXILResourceWrapperPass>();
   }
   static char ID; // Pass identification.
 };
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index fd8b2791e371e4..a82b522773ab22 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -24,6 +24,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/VersionTuple.h"
 #include "llvm/TargetParser/Triple.h"
 #include <cstdint>
@@ -34,7 +35,7 @@ using namespace llvm::dxil;
 /// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
 class DiagnosticInfoModuleFormat : public DiagnosticInfo {
 private:
-  Twine Msg;
+  const Twine Msg;
   const Module &Mod;
 
 public:
@@ -45,12 +46,6 @@ class DiagnosticInfoModuleFormat : public DiagnosticInfo {
                              DiagnosticSeverity Severity = DS_Error)
       : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
 
-  static bool classof(const DiagnosticInfo *DI) {
-    return DI->getKind() == DK_Unsupported;
-  }
-
-  const Twine &getMessage() const { return Msg; }
-
   void print(DiagnosticPrinter &DP) const override {
     std::string Str;
     raw_string_ostream OS(Str);
@@ -166,7 +161,7 @@ getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
         ConstantInt::get(Type::getInt32Ty(Ctx), Value)));
     break;
   default:
-    assert(false && "NYI: Unhandled entry property tag");
+    llvm_unreachable("NYI: Unhandled entry property tag");
   }
   return MDVals;
 }
@@ -191,13 +186,12 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
     if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
       MDVals.emplace_back(ConstantAsMetadata::get(
           ConstantInt::get(Type::getInt32Ty(Ctx), NumThreadsTag)));
-      std::vector<Metadata *> NumThreadVals;
-      NumThreadVals.emplace_back(ConstantAsMetadata::get(
-          ConstantInt::get(Type::getInt32Ty(Ctx), EP.NumThreadsX)));
-      NumThreadVals.emplace_back(ConstantAsMetadata::get(
-          ConstantInt::get(Type::getInt32Ty(Ctx), EP.NumThreadsY)));
-      NumThreadVals.emplace_back(ConstantAsMetadata::get(
-          ConstantInt::get(Type::getInt32Ty(Ctx), EP.NumThreadsZ)));
+      Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
+                                       Type::getInt32Ty(Ctx), EP.NumThreadsX)),
+                                   ConstantAsMetadata::get(ConstantInt::get(
+                                       Type::getInt32Ty(Ctx), EP.NumThreadsY)),
+                                   ConstantAsMetadata::get(ConstantInt::get(
+                                       Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
       MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
     }
   }
@@ -333,8 +327,8 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
     if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
       if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
         M.getContext().diagnose(DiagnosticInfoModuleFormat(
-            M, "Non-library shader: Stage of Shader entry different from "
-               "target profile"));
+            M, "Non-library shader: Stage of shader entry different from "
+               "shader target profile"));
       }
     }
     EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
diff --git a/llvm/test/CodeGen/DirectX/Metadata/target-profile-error.ll b/llvm/test/CodeGen/DirectX/Metadata/target-profile-error.ll
index 22a47c356fa70a..a6061d0bf47ff9 100644
--- a/llvm/test/CodeGen/DirectX/Metadata/target-profile-error.ll
+++ b/llvm/test/CodeGen/DirectX/Metadata/target-profile-error.ll
@@ -2,7 +2,7 @@
 
 target triple = "dxil-pc-shadermodel6.6-pixel"
 
-; CHECK: Non-library shader: Stage of Shader entry different from target profile
+; CHECK: Non-library shader: Stage of shader entry different from shader target profile
 
 define void @entry() #0 {
 entry:

>From 292f7e588548646fb3e4800fd9edb4da44f5dd16 Mon Sep 17 00:00:00 2001
From: "S. Bharadwaj Yadavalli" <Bharadwaj.Yadavalli at microsoft.com>
Date: Thu, 19 Sep 2024 12:13:47 -0400
Subject: [PATCH 8/8] Address more PR feedback - Rename class
 DiagnosticInfoModuleFormat as class DiagnosticInfoTranslateMD and enclose
 it in an anonymous namespace. - Print message directly to DiagnosticPrinter
 in DiagnosticInfoTranslateMD::print() - Change enum EntryPropsTag to scoped
 enum class - List all EntryPropsTags in case statement - Moved comment inside
 constructEntryMetadata() - Changed emitValidatorVersionMD() to return early
 as appropriate.

---
 .../Target/DirectX/DXILTranslateMetadata.cpp  | 99 +++++++++++--------
 .../DirectX/Metadata/target-profile-error.ll  |  2 +-
 2 files changed, 58 insertions(+), 43 deletions(-)

diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index a82b522773ab22..9f1a78b6cfa542 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -12,6 +12,7 @@
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
 #include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/Analysis/DXILResource.h"
 #include "llvm/IR/Constants.h"
@@ -32,30 +33,45 @@
 using namespace llvm;
 using namespace llvm::dxil;
 
+namespace {
 /// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
-class DiagnosticInfoModuleFormat : public DiagnosticInfo {
+/// for TranslateMetadata pass
+class DiagnosticInfoTranslateMD : public DiagnosticInfo {
 private:
-  const Twine Msg;
+  const Twine &Msg;
   const Module &Mod;
 
 public:
   /// \p M is the module for which the diagnostic is being emitted. \p Msg is
   /// the message to show. Note that this class does not copy this message, so
   /// this reference must be valid for the whole life time of the diagnostic.
-  DiagnosticInfoModuleFormat(const Module &M, const Twine &Msg,
-                             DiagnosticSeverity Severity = DS_Error)
+  DiagnosticInfoTranslateMD(const Module &M, const Twine &Msg,
+                            DiagnosticSeverity Severity = DS_Error)
       : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
 
   void print(DiagnosticPrinter &DP) const override {
-    std::string Str;
-    raw_string_ostream OS(Str);
-
-    OS << Mod.getName() << ": " << Msg << '\n';
-    OS.flush();
-    DP << Str;
+    DP << Mod.getName() << ": " << Msg << '\n';
   }
 };
 
+enum class EntryPropsTag {
+  ShaderFlags = 0,
+  GSState,
+  DSState,
+  HSState,
+  NumThreads,
+  AutoBindingSpace,
+  RayPayloadSize,
+  RayAttribSize,
+  ShaderKind,
+  MSState,
+  ASStateTag,
+  WaveSize,
+  EntryRootSig,
+};
+
+} // namespace
+
 static NamedMDNode *emitResourceMetadata(Module &M, const DXILResourceMap &DRM,
                                          const dxil::Resources &MDResources) {
   LLVMContext &Context = M.getContext();
@@ -128,39 +144,31 @@ static uint32_t getShaderStage(Triple::EnvironmentType Env) {
   return (uint32_t)Env - (uint32_t)llvm::Triple::Pixel;
 }
 
-namespace {
-enum EntryPropsTag {
-  ShaderFlagsTag = 0,
-  GSStateTag,
-  DSStateTag,
-  HSStateTag,
-  NumThreadsTag,
-  AutoBindingSpaceTag,
-  RayPayloadSizeTag,
-  RayAttribSizeTag,
-  ShaderKindTag,
-  MSStateTag,
-  ASStateTag,
-  WaveSizeTag,
-  EntryRootSigTag,
-};
-} // namespace
-
 static SmallVector<Metadata *>
 getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
   SmallVector<Metadata *> MDVals;
-  MDVals.emplace_back(
-      ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Ctx), Tag)));
+  MDVals.emplace_back(ConstantAsMetadata::get(
+      ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag))));
   switch (Tag) {
-  case ShaderFlagsTag:
+  case EntryPropsTag::ShaderFlags:
     MDVals.emplace_back(ConstantAsMetadata::get(
         ConstantInt::get(Type::getInt64Ty(Ctx), Value)));
     break;
-  case ShaderKindTag:
+  case EntryPropsTag::ShaderKind:
     MDVals.emplace_back(ConstantAsMetadata::get(
         ConstantInt::get(Type::getInt32Ty(Ctx), Value)));
     break;
-  default:
+  case EntryPropsTag::GSState:
+  case EntryPropsTag::DSState:
+  case EntryPropsTag::HSState:
+  case EntryPropsTag::NumThreads:
+  case EntryPropsTag::AutoBindingSpace:
+  case EntryPropsTag::RayPayloadSize:
+  case EntryPropsTag::RayAttribSize:
+  case EntryPropsTag::MSState:
+  case EntryPropsTag::ASStateTag:
+  case EntryPropsTag::WaveSize:
+  case EntryPropsTag::EntryRootSig:
     llvm_unreachable("NYI: Unhandled entry property tag");
   }
   return MDVals;
@@ -172,7 +180,8 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
   SmallVector<Metadata *> MDVals;
   LLVMContext &Ctx = EP.Entry->getContext();
   if (EntryShaderFlags != 0)
-    MDVals.append(getTagValueAsMetadata(ShaderFlagsTag, EntryShaderFlags, Ctx));
+    MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags,
+                                        EntryShaderFlags, Ctx));
 
   if (EP.Entry != nullptr) {
     // FIXME: support more props.
@@ -180,12 +189,12 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
     // Add shader kind for lib entries.
     if (ShaderProfile == Triple::EnvironmentType::Library &&
         EP.ShaderStage != Triple::EnvironmentType::Library)
-      MDVals.append(getTagValueAsMetadata(ShaderKindTag,
+      MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
                                           getShaderStage(EP.ShaderStage), Ctx));
 
     if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
-      MDVals.emplace_back(ConstantAsMetadata::get(
-          ConstantInt::get(Type::getInt32Ty(Ctx), NumThreadsTag)));
+      MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
+          Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads))));
       Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
                                        Type::getInt32Ty(Ctx), EP.NumThreadsX)),
                                    ConstantAsMetadata::get(ConstantInt::get(
@@ -282,7 +291,8 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
     // provided by ShaderFlagsAnalysis pass is created by walking *all* the
     // function instructions of the module. Is it is correct to use this value
     // for metadata of the empty library entry?
-    MDVals.append(getTagValueAsMetadata(ShaderFlagsTag, ShaderFlags, Ctx));
+    MDVals.append(
+        getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx));
     Properties = MDNode::get(Ctx, MDVals);
   }
   // Library has an entry metadata with resource table metadata and all other
@@ -312,7 +322,7 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
     EntryFnMDNodes.emplace_back(
         emitTopLevelLibraryNode(M, ResourceMD, ShaderFlags));
   else if (MMDI.EntryPropertyVec.size() > 1) {
-    M.getContext().diagnose(DiagnosticInfoModuleFormat(
+    M.getContext().diagnose(DiagnosticInfoTranslateMD(
         M, "Non-library shader: One and only one entry expected"));
   }
 
@@ -326,9 +336,14 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
                                                                  : ShaderFlags;
     if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
       if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
-        M.getContext().diagnose(DiagnosticInfoModuleFormat(
-            M, "Non-library shader: Stage of shader entry different from "
-               "shader target profile"));
+        M.getContext().diagnose(DiagnosticInfoTranslateMD(
+            M,
+            "Shader stage '" +
+                Twine(getShortShaderStage(EntryProp.ShaderStage) +
+                      "' for entry '" + Twine(EntryProp.Entry->getName()) +
+                      "' different from specified target profile '" +
+                      Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
+                            "'"))));
       }
     }
     EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
diff --git a/llvm/test/CodeGen/DirectX/Metadata/target-profile-error.ll b/llvm/test/CodeGen/DirectX/Metadata/target-profile-error.ll
index a6061d0bf47ff9..671406cb5d3644 100644
--- a/llvm/test/CodeGen/DirectX/Metadata/target-profile-error.ll
+++ b/llvm/test/CodeGen/DirectX/Metadata/target-profile-error.ll
@@ -2,7 +2,7 @@
 
 target triple = "dxil-pc-shadermodel6.6-pixel"
 
-; CHECK: Non-library shader: Stage of shader entry different from shader target profile
+; CHECK: Shader stage 'cs' for entry 'entry' different from specified target profile 'pixel'
 
 define void @entry() #0 {
 entry:



More information about the llvm-commits mailing list