[llvm] 1ff5f32 - [DXIL] Add support for root signature flag element in DXContainer (#123147)

via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 13 14:16:05 PST 2025


Author: joaosaffran
Date: 2025-02-13T14:16:01-08:00
New Revision: 1ff5f328d9824694cc356ebf78adad8816a6de86

URL: https://github.com/llvm/llvm-project/commit/1ff5f328d9824694cc356ebf78adad8816a6de86
DIFF: https://github.com/llvm/llvm-project/commit/1ff5f328d9824694cc356ebf78adad8816a6de86.diff

LOG: [DXIL] Add support for root signature flag element in DXContainer (#123147)

Add support for extracting the Root Signature Flags element from metadata
and writing it to the DXContainer.
- Add an analysis that parses the root signature metadata definitions
- Add validation for the root signature flags
- Write the root signature blob into the DXContainer as the RTS0 part

Closes: [126632](https://github.com/llvm/llvm-project/issues/126632)

---------

Co-authored-by: joaosaffran <joao.saffran at microsoft.com>

Added: 
    llvm/lib/Target/DirectX/DXILRootSignature.cpp
    llvm/lib/Target/DirectX/DXILRootSignature.h
    llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Error.ll
    llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll
    llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Validation-Error.ll
    llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll
    llvm/test/CodeGen/DirectX/ContainerData/RootSignature-MultipleEntryFunctions.ll
    llvm/test/CodeGen/DirectX/ContainerData/RootSignature-NullFunction-Error.ll
    llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootElement-Error.ll

Modified: 
    llvm/include/llvm/BinaryFormat/DXContainer.h
    llvm/include/llvm/MC/DXContainerRootSignature.h
    llvm/include/llvm/ObjectYAML/DXContainerYAML.h
    llvm/lib/MC/DXContainerRootSignature.cpp
    llvm/lib/Object/DXContainer.cpp
    llvm/lib/ObjectYAML/DXContainerEmitter.cpp
    llvm/lib/ObjectYAML/DXContainerYAML.cpp
    llvm/lib/Target/DirectX/CMakeLists.txt
    llvm/lib/Target/DirectX/DXContainerGlobals.cpp
    llvm/lib/Target/DirectX/DirectX.h
    llvm/lib/Target/DirectX/DirectXPassRegistry.def
    llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
    llvm/test/CodeGen/DirectX/llc-pipeline.ll
    llvm/tools/obj2yaml/dxcontainer2yaml.cpp
    llvm/unittests/Object/DXContainerTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/BinaryFormat/DXContainer.h b/llvm/include/llvm/BinaryFormat/DXContainer.h
index fbab066bf4517..bd5a796c0b31c 100644
--- a/llvm/include/llvm/BinaryFormat/DXContainer.h
+++ b/llvm/include/llvm/BinaryFormat/DXContainer.h
@@ -14,8 +14,6 @@
 #define LLVM_BINARYFORMAT_DXCONTAINER_H
 
 #include "llvm/ADT/StringRef.h"
-#include "llvm/Support/BinaryStreamError.h"
-#include "llvm/Support/Error.h"
 #include "llvm/Support/SwapByteOrder.h"
 #include "llvm/TargetParser/Triple.h"
 
@@ -550,18 +548,10 @@ static_assert(sizeof(ProgramSignatureElement) == 32,
 
+/// Range checks for root signature header fields, used both when reading a
+/// DXContainer (DirectX::RootSignature::parse) and when building one from
+/// metadata in the DirectX backend.
 struct RootSignatureValidations {
 
-  static Expected<uint32_t> validateRootFlag(uint32_t Flags) {
-    if ((Flags & ~0x80000fff) != 0)
-      return llvm::make_error<BinaryStreamError>("Invalid Root Signature flag");
-    return Flags;
-  }
-
-  static Expected<uint32_t> validateVersion(uint32_t Version) {
-    if (Version == 1 || Version == 2)
-      return Version;
+  /// Returns true when Flags sets no bits outside the low 12 bits (0xfff).
+  static bool isValidRootFlag(uint32_t Flags) { return (Flags & ~0xfff) == 0; }
 
-    return llvm::make_error<BinaryStreamError>(
-        "Invalid Root Signature Version");
+  /// Returns true when Version is 1 or 2, the only accepted versions.
+  static bool isValidVersion(uint32_t Version) {
+    return (Version == 1 || Version == 2);
   }
 };
 

diff  --git a/llvm/include/llvm/MC/DXContainerRootSignature.h b/llvm/include/llvm/MC/DXContainerRootSignature.h
index e1a9be5fc52d8..e414112498798 100644
--- a/llvm/include/llvm/MC/DXContainerRootSignature.h
+++ b/llvm/include/llvm/MC/DXContainerRootSignature.h
@@ -14,7 +14,7 @@ namespace llvm {
 class raw_ostream;
 
 namespace mcdxbc {
-struct RootSignatureHeader {
+struct RootSignatureDesc {
   uint32_t Version = 2;
   uint32_t NumParameters = 0;
   uint32_t RootParametersOffset = 0;
@@ -22,7 +22,7 @@ struct RootSignatureHeader {
   uint32_t StaticSamplersOffset = 0;
   uint32_t Flags = 0;
 
-  void write(raw_ostream &OS);
+  void write(raw_ostream &OS) const;
 };
 } // namespace mcdxbc
 } // namespace llvm

diff  --git a/llvm/include/llvm/ObjectYAML/DXContainerYAML.h b/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
index 0200f5cb196ff..ecad35e82b155 100644
--- a/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
+++ b/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
@@ -74,9 +74,9 @@ struct ShaderHash {
 };
 
 #define ROOT_ELEMENT_FLAG(Num, Val) bool Val = false;
-struct RootSignatureDesc {
-  RootSignatureDesc() = default;
-  RootSignatureDesc(const object::DirectX::RootSignature &Data);
+struct RootSignatureYamlDesc {
+  RootSignatureYamlDesc() = default;
+  RootSignatureYamlDesc(const object::DirectX::RootSignature &Data);
 
   uint32_t Version;
   uint32_t NumParameters;
@@ -176,7 +176,7 @@ struct Part {
   std::optional<ShaderHash> Hash;
   std::optional<PSVInfo> Info;
   std::optional<DXContainerYAML::Signature> Signature;
-  std::optional<DXContainerYAML::RootSignatureDesc> RootSignature;
+  std::optional<DXContainerYAML::RootSignatureYamlDesc> RootSignature;
 };
 
 struct Object {
@@ -259,9 +259,9 @@ template <> struct MappingTraits<DXContainerYAML::Signature> {
   static void mapping(IO &IO, llvm::DXContainerYAML::Signature &El);
 };
 
-template <> struct MappingTraits<DXContainerYAML::RootSignatureDesc> {
+template <> struct MappingTraits<DXContainerYAML::RootSignatureYamlDesc> {
   static void mapping(IO &IO,
-                      DXContainerYAML::RootSignatureDesc &RootSignature);
+                      DXContainerYAML::RootSignatureYamlDesc &RootSignature);
 };
 
 } // namespace yaml

diff  --git a/llvm/lib/MC/DXContainerRootSignature.cpp b/llvm/lib/MC/DXContainerRootSignature.cpp
index 000d23f24d241..b6f2b85bac74e 100644
--- a/llvm/lib/MC/DXContainerRootSignature.cpp
+++ b/llvm/lib/MC/DXContainerRootSignature.cpp
@@ -12,7 +12,7 @@
 using namespace llvm;
 using namespace llvm::mcdxbc;
 
-void RootSignatureHeader::write(raw_ostream &OS) {
+void RootSignatureDesc::write(raw_ostream &OS) const {
 
   support::endian::write(OS, Version, llvm::endianness::little);
   support::endian::write(OS, NumParameters, llvm::endianness::little);

diff  --git a/llvm/lib/Object/DXContainer.cpp b/llvm/lib/Object/DXContainer.cpp
index f28b096008b2f..1eb1453c65147 100644
--- a/llvm/lib/Object/DXContainer.cpp
+++ b/llvm/lib/Object/DXContainer.cpp
@@ -20,6 +20,10 @@ static Error parseFailed(const Twine &Msg) {
   return make_error<GenericBinaryError>(Msg.str(), object_error::parse_failed);
 }
 
+/// Wraps Msg in a StringError tagged with inconvertibleErrorCode().
+///
+/// RootSignature::parse uses this when a field was read successfully but holds
+/// a value the format does not accept; compare parseFailed, which produces a
+/// GenericBinaryError tagged object_error::parse_failed.
+static Error validationFailed(const Twine &Msg) {
+  const std::error_code NoCode = inconvertibleErrorCode();
+  return make_error<StringError>(Msg.str(), NoCode);
+}
+
 template <typename T>
 static Error readStruct(StringRef Buffer, const char *Src, T &Struct) {
   // Don't read before the beginning or past the end of the file
@@ -254,11 +258,10 @@ Error DirectX::RootSignature::parse(StringRef Data) {
       support::endian::read<uint32_t, llvm::endianness::little>(Current);
   Current += sizeof(uint32_t);
 
-  Expected<uint32_t> MaybeVersion =
-      dxbc::RootSignatureValidations::validateVersion(VValue);
-  if (Error E = MaybeVersion.takeError())
-    return E;
-  Version = MaybeVersion.get();
+  if (!dxbc::RootSignatureValidations::isValidVersion(VValue))
+    return validationFailed("unsupported root signature version read: " +
+                            llvm::Twine(VValue));
+  Version = VValue;
 
   NumParameters =
       support::endian::read<uint32_t, llvm::endianness::little>(Current);
@@ -280,11 +283,10 @@ Error DirectX::RootSignature::parse(StringRef Data) {
       support::endian::read<uint32_t, llvm::endianness::little>(Current);
   Current += sizeof(uint32_t);
 
-  Expected<uint32_t> MaybeFlag =
-      dxbc::RootSignatureValidations::validateRootFlag(FValue);
-  if (Error E = MaybeFlag.takeError())
-    return E;
-  Flags = MaybeFlag.get();
+  if (!dxbc::RootSignatureValidations::isValidRootFlag(FValue))
+    return validationFailed("unsupported root signature flag value read: " +
+                            llvm::Twine(FValue));
+  Flags = FValue;
 
   return Error::success();
 }

diff  --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
index b7d1c6558fa1f..f6ed09c857bb7 100644
--- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
@@ -266,15 +266,15 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
       if (!P.RootSignature.has_value())
         continue;
 
-      mcdxbc::RootSignatureHeader Header;
-      Header.Flags = P.RootSignature->getEncodedFlags();
-      Header.Version = P.RootSignature->Version;
-      Header.NumParameters = P.RootSignature->NumParameters;
-      Header.RootParametersOffset = P.RootSignature->RootParametersOffset;
-      Header.NumStaticSamplers = P.RootSignature->NumStaticSamplers;
-      Header.StaticSamplersOffset = P.RootSignature->StaticSamplersOffset;
-
-      Header.write(OS);
+      mcdxbc::RootSignatureDesc RS;
+      RS.Flags = P.RootSignature->getEncodedFlags();
+      RS.Version = P.RootSignature->Version;
+      RS.NumParameters = P.RootSignature->NumParameters;
+      RS.RootParametersOffset = P.RootSignature->RootParametersOffset;
+      RS.NumStaticSamplers = P.RootSignature->NumStaticSamplers;
+      RS.StaticSamplersOffset = P.RootSignature->StaticSamplersOffset;
+
+      RS.write(OS);
       break;
     }
     uint64_t BytesWritten = OS.tell() - DataStart;

diff  --git a/llvm/lib/ObjectYAML/DXContainerYAML.cpp b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
index 0869fd4fa9785..f03c7da65999d 100644
--- a/llvm/lib/ObjectYAML/DXContainerYAML.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
@@ -29,7 +29,7 @@ DXContainerYAML::ShaderFeatureFlags::ShaderFeatureFlags(uint64_t FlagData) {
 #include "llvm/BinaryFormat/DXContainerConstants.def"
 }
 
-DXContainerYAML::RootSignatureDesc::RootSignatureDesc(
+DXContainerYAML::RootSignatureYamlDesc::RootSignatureYamlDesc(
     const object::DirectX::RootSignature &Data)
     : Version(Data.getVersion()), NumParameters(Data.getNumParameters()),
       RootParametersOffset(Data.getRootParametersOffset()),
@@ -41,7 +41,7 @@ DXContainerYAML::RootSignatureDesc::RootSignatureDesc(
 #include "llvm/BinaryFormat/DXContainerConstants.def"
 }
 
-uint32_t DXContainerYAML::RootSignatureDesc::getEncodedFlags() {
+uint32_t DXContainerYAML::RootSignatureYamlDesc::getEncodedFlags() {
   uint64_t Flag = 0;
 #define ROOT_ELEMENT_FLAG(Num, Val)                                            \
   if (Val)                                                                     \
@@ -209,8 +209,8 @@ void MappingTraits<DXContainerYAML::Signature>::mapping(
   IO.mapRequired("Parameters", S.Parameters);
 }
 
-void MappingTraits<DXContainerYAML::RootSignatureDesc>::mapping(
-    IO &IO, DXContainerYAML::RootSignatureDesc &S) {
+void MappingTraits<DXContainerYAML::RootSignatureYamlDesc>::mapping(
+    IO &IO, DXContainerYAML::RootSignatureYamlDesc &S) {
   IO.mapRequired("Version", S.Version);
   IO.mapRequired("NumParameters", S.NumParameters);
   IO.mapRequired("RootParametersOffset", S.RootParametersOffset);

diff  --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index 26315db891b57..5a167535b0afa 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -33,7 +33,8 @@ add_llvm_target(DirectXCodeGen
   DXILResourceAccess.cpp
   DXILShaderFlags.cpp
   DXILTranslateMetadata.cpp
-
+  DXILRootSignature.cpp
+  
   LINK_COMPONENTS
   Analysis
   AsmPrinter

diff  --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index 7a0bd6a7c8869..5508af40663b1 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "DXILRootSignature.h"
 #include "DXILShaderFlags.h"
 #include "DirectX.h"
 #include "llvm/ADT/SmallVector.h"
@@ -25,7 +26,9 @@
 #include "llvm/MC/DXContainerPSVInfo.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/MD5.h"
+#include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
+#include <optional>
 
 using namespace llvm;
 using namespace llvm::dxil;
@@ -41,6 +44,7 @@ class DXContainerGlobals : public llvm::ModulePass {
   GlobalVariable *buildSignature(Module &M, Signature &Sig, StringRef Name,
                                  StringRef SectionName);
   void addSignature(Module &M, SmallVector<GlobalValue *> &Globals);
+  void addRootSignature(Module &M, SmallVector<GlobalValue *> &Globals);
   void addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV);
   void addPipelineStateValidationInfo(Module &M,
                                       SmallVector<GlobalValue *> &Globals);
@@ -60,6 +64,7 @@ class DXContainerGlobals : public llvm::ModulePass {
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesAll();
     AU.addRequired<ShaderFlagsAnalysisWrapper>();
+    AU.addRequired<RootSignatureAnalysisWrapper>();
     AU.addRequired<DXILMetadataAnalysisWrapperPass>();
     AU.addRequired<DXILResourceTypeWrapperPass>();
     AU.addRequired<DXILResourceBindingWrapperPass>();
@@ -73,6 +78,7 @@ bool DXContainerGlobals::runOnModule(Module &M) {
   Globals.push_back(getFeatureFlags(M));
   Globals.push_back(computeShaderHash(M));
   addSignature(M, Globals);
+  addRootSignature(M, Globals);
   addPipelineStateValidationInfo(M, Globals);
   appendToCompilerUsed(M, Globals);
   return true;
@@ -144,6 +150,36 @@ void DXContainerGlobals::addSignature(Module &M,
   Globals.emplace_back(buildSignature(M, OutputSig, "dx.osg1", "OSG1"));
 }
 
+/// Emits the entry function's root signature as the "RTS0" container part,
+/// stored in the global "dx.rts0".
+///
+/// Nothing is emitted for library modules, or when RootSignatureAnalysisWrapper
+/// has no root signature recorded for the entry function.
+void DXContainerGlobals::addRootSignature(Module &M,
+                                          SmallVector<GlobalValue *> &Globals) {
+
+  dxil::ModuleMetadataInfo &MMI =
+      getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
+
+  // Root signatures in libraries are not compiled into the DXContainer.
+  if (MMI.ShaderProfile == llvm::Triple::Library)
+    return;
+
+  assert(MMI.EntryPropertyVec.size() == 1 &&
+         "Non-library shaders are expected to have exactly one entry function");
+
+  auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>();
+  const Function *EntryFunction = MMI.EntryPropertyVec[0].Entry;
+  // find() returns the iterator by value; hold it by value instead of binding
+  // a reference to the temporary.
+  auto FuncRs = RSA.find(EntryFunction);
+  if (FuncRs == RSA.end())
+    return;
+
+  const RootSignatureDesc &RS = FuncRs->second;
+  SmallString<256> Data;
+  raw_svector_ostream OS(Data);
+
+  RS.write(OS);
+
+  Constant *Constant =
+      ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
+  Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.rts0", "RTS0"));
+}
+
 void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) {
   const DXILBindingMap &DBM =
       getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap();

diff  --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
new file mode 100644
index 0000000000000..49fc892eade5d
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -0,0 +1,229 @@
+//===- DXILRootSignature.cpp - DXIL Root Signature helper objects -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains helper objects and APIs for working with DXIL
+///       Root Signatures.
+///
+//===----------------------------------------------------------------------===//
+#include "DXILRootSignature.h"
+#include "DirectX.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/DXILMetadataAnalysis.h"
+#include "llvm/BinaryFormat/DXContainer.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstdint>
+#include <optional>
+#include <utility>
+
+using namespace llvm;
+using namespace llvm::dxil;
+
+/// Reports Message through Ctx's diagnostic handler (DS_Error by default).
+///
+/// Always returns true so the parse/validate helpers, which signal failure by
+/// returning true, can write `return reportError(...)`. Message is taken as
+/// `const Twine &`, the only form in which a Twine should cross an API.
+static bool reportError(LLVMContext *Ctx, const Twine &Message,
+                        DiagnosticSeverity Severity = DS_Error) {
+  Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
+  return true;
+}
+
+/// Parses a RootFlags element, `!{ !"RootFlags", i32 <flags> }`, into
+/// RSD.Flags.
+///
+/// Returns true (after reporting a diagnostic) if the element does not have
+/// exactly two operands, or if its second operand is not an integer constant
+/// that fits in 32 bits.
+static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
+                           MDNode *RootFlagNode) {
+
+  if (RootFlagNode->getNumOperands() != 2)
+    return reportError(Ctx, "Invalid format for RootFlag Element");
+
+  // mdconst::extract asserts on a null or non-ConstantInt operand; the dyn_
+  // variant lets malformed metadata produce a diagnostic instead of a crash.
+  auto *Flag = mdconst::dyn_extract_or_null<ConstantInt>(
+      RootFlagNode->getOperand(1).get());
+  // Flags is 32 bits wide: a wider constant would be silently truncated and
+  // could then slip past flag validation.
+  if (Flag == nullptr || Flag->getValue().getActiveBits() > 32)
+    return reportError(Ctx, "Invalid value for RootFlag Element");
+  RSD.Flags = Flag->getZExtValue();
+
+  return false;
+}
+
+/// Parses one root element node, dispatching on its leading MDString name.
+///
+/// Returns true (after reporting a diagnostic) if the node has no operands,
+/// its first operand is not an MDString, the name is not a recognized element
+/// kind, or the element-specific parser fails.
+static bool parseRootSignatureElement(LLVMContext *Ctx,
+                                      mcdxbc::RootSignatureDesc &RSD,
+                                      MDNode *Element) {
+  // cast<> asserts rather than returning null, which made the null check below
+  // unreachable. Use a checked cast (and guard the empty node) so malformed
+  // elements reach the diagnostic instead of crashing.
+  MDString *ElementText =
+      Element->getNumOperands() == 0
+          ? nullptr
+          : dyn_cast_or_null<MDString>(Element->getOperand(0).get());
+  if (ElementText == nullptr)
+    return reportError(Ctx, "Invalid format for Root Element");
+
+  RootSignatureElementKind ElementKind =
+      StringSwitch<RootSignatureElementKind>(ElementText->getString())
+          .Case("RootFlags", RootSignatureElementKind::RootFlags)
+          .Default(RootSignatureElementKind::Error);
+
+  switch (ElementKind) {
+
+  case RootSignatureElementKind::RootFlags:
+    return parseRootFlags(Ctx, RSD, Element);
+  case RootSignatureElementKind::Error:
+    return reportError(Ctx, "Invalid Root Signature Element: " +
+                                ElementText->getString());
+  }
+
+  llvm_unreachable("Unhandled RootSignatureElementKind enum.");
+}
+
+/// Parses every root element listed in Node into RSD.
+///
+/// Every element is parsed even after an earlier one fails, so all malformed
+/// elements get diagnosed. Returns true if any element failed. It also returns
+/// true immediately if an operand of Node is not itself a metadata node.
+static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
+                  MDNode *Node) {
+  bool HasError = false;
+
+  // Loop through the Root Elements of the root signature.
+  for (const auto &Operand : Node->operands()) {
+    // Operands may be null; plain dyn_cast would assert on them.
+    MDNode *Element = dyn_cast_or_null<MDNode>(Operand.get());
+    if (Element == nullptr)
+      return reportError(Ctx, "Missing Root Element Metadata Node.");
+
+    // Call the parser first: `HasError || parse(...)` short-circuits and
+    // would skip every element after the first failure.
+    HasError = parseRootSignatureElement(Ctx, RSD, Element) || HasError;
+  }
+
+  return HasError;
+}
+
+/// Rejects parsed root signature values that the format does not allow.
+/// Returns true (after reporting a diagnostic) on failure, false otherwise.
+static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
+  // Flags is the only field filled in from metadata (by parseRootFlags), so
+  // it is the only one checked here.
+  if (dxbc::RootSignatureValidations::isValidRootFlag(RSD.Flags))
+    return false;
+  return reportError(Ctx, "Invalid Root Signature flag value");
+}
+
+/// Builds the function -> root signature map from M's `dx.rootsignatures`
+/// named metadata.
+///
+/// Root signatures are specified in metadata as follows:
+///
+///   !dx.rootsignatures = !{!2} ; list of function/root signature pairs
+///   !2 = !{ ptr @main, !3 }    ; function, root signature
+///   !3 = !{ !4, !5, !6, !7 }   ; list of root signature elements
+///
+/// Each operand of `dx.rootsignatures` is one (function, element list) pair.
+/// A malformed pair is reported through M's LLVMContext and skipped, and the
+/// remaining pairs are still processed. A function whose root signature fails
+/// to parse or validate is left out of the result.
+static SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>
+analyzeModule(Module &M) {
+  LLVMContext *Ctx = &M.getContext();
+
+  SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> RSDMap;
+
+  NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures");
+  if (RootSignatureNode == nullptr)
+    return RSDMap;
+
+  for (const auto &RSDefNode : RootSignatureNode->operands()) {
+    if (RSDefNode->getNumOperands() != 2) {
+      reportError(Ctx, "Invalid format for Root Signature Definition. Pairs "
+                       "of function, root signature expected.");
+      continue;
+    }
+
+    // Function was pruned during compilation.
+    const MDOperand &FunctionPointerMdNode = RSDefNode->getOperand(0);
+    if (FunctionPointerMdNode == nullptr) {
+      reportError(
+          Ctx, "Function associated with Root Signature definition is null.");
+      continue;
+    }
+
+    ValueAsMetadata *VAM =
+        llvm::dyn_cast<ValueAsMetadata>(FunctionPointerMdNode.get());
+    if (VAM == nullptr) {
+      reportError(Ctx, "First element of root signature is not a Value");
+      continue;
+    }
+
+    Function *F = dyn_cast<Function>(VAM->getValue());
+    if (F == nullptr) {
+      reportError(Ctx, "First element of root signature is not a Function");
+      continue;
+    }
+
+    // The operand may be null; dyn_cast (unlike dyn_cast_or_null) asserts.
+    MDNode *RootElementListNode =
+        dyn_cast_or_null<MDNode>(RSDefNode->getOperand(1).get());
+
+    // Skip the pair: parse() below would dereference a null node.
+    if (RootElementListNode == nullptr) {
+      reportError(Ctx, "Missing Root Element List Metadata node.");
+      continue;
+    }
+
+    mcdxbc::RootSignatureDesc RSD;
+
+    // Drop only this definition on failure, matching the handling above,
+    // rather than abandoning every definition that follows it.
+    if (parse(Ctx, RSD, RootElementListNode) || validate(Ctx, RSD))
+      continue;
+
+    RSDMap.insert(std::make_pair(F, RSD));
+  }
+
+  return RSDMap;
+}
+
+AnalysisKey RootSignatureAnalysis::Key;
+
+/// New-pass-manager entry point. The analysis manager is not consulted: the
+/// map is built solely from the module's `dx.rootsignatures` metadata.
+RootSignatureAnalysis::Result
+RootSignatureAnalysis::run(Module &M, ModuleAnalysisManager & /*AM*/) {
+  Result FunctionToSignature = analyzeModule(M);
+  return FunctionToSignature;
+}
+
+//===----------------------------------------------------------------------===//
+
+/// Prints each root signature definition found by RootSignatureAnalysis.
+///
+/// The result map is keyed by Function pointer, so iterating it directly gives
+/// an order that depends on pointer hashing and can change between runs.
+/// Walking the module's functions instead prints definitions in module order.
+PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
+                                                    ModuleAnalysisManager &AM) {
+
+  SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> &RSDMap =
+      AM.getResult<RootSignatureAnalysis>(M);
+  OS << "Root Signature Definitions"
+     << "\n";
+  uint8_t Space = 0;
+  for (const Function &F : M) {
+    auto It = RSDMap.find(&F);
+    if (It == RSDMap.end())
+      continue;
+    const mcdxbc::RootSignatureDesc &RSD = It->second;
+    OS << "Definition for '" << F.getName() << "':\n";
+
+    // start root signature header
+    // NOTE(review): each value line also ends in ":"; left unchanged because
+    // FileCheck tests may match this exact text.
+    Space++;
+    OS << indent(Space) << "Flags: " << format_hex(RSD.Flags, 8) << ":\n";
+    OS << indent(Space) << "Version: " << RSD.Version << ":\n";
+    OS << indent(Space) << "NumParameters: " << RSD.NumParameters << ":\n";
+    OS << indent(Space) << "RootParametersOffset: " << RSD.RootParametersOffset
+       << ":\n";
+    OS << indent(Space) << "NumStaticSamplers: " << RSD.NumStaticSamplers
+       << ":\n";
+    OS << indent(Space) << "StaticSamplersOffset: " << RSD.StaticSamplersOffset
+       << ":\n";
+    Space--;
+    // end root signature header
+  }
+
+  return PreservedAnalyses::all();
+}
+
+//===----------------------------------------------------------------------===//
+/// Legacy-pass-manager entry point: rebuilds the cached function -> root
+/// signature map served by find()/end(). The IR is not modified, hence false.
+bool RootSignatureAnalysisWrapper::runOnModule(Module &M) {
+  auto Parsed = analyzeModule(M);
+  FuncToRsMap = std::move(Parsed);
+  return false;
+}
+
+/// Marks every analysis preserved and requires the DXIL metadata analysis.
+/// NOTE(review): analyzeModule does not read DXILMetadataAnalysisWrapperPass's
+/// result; confirm whether this dependency is still needed.
+void RootSignatureAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addRequired<DXILMetadataAnalysisWrapperPass>();
+  AU.setPreservesAll();
+}
+
+char RootSignatureAnalysisWrapper::ID = 0;
+
+// getAnalysisUsage requires DXILMetadataAnalysisWrapperPass, so declare it as
+// an initialization dependency; that way it is registered before this pass.
+INITIALIZE_PASS_BEGIN(RootSignatureAnalysisWrapper,
+                      "dxil-root-signature-analysis",
+                      "DXIL Root Signature Analysis", true, true)
+INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
+INITIALIZE_PASS_END(RootSignatureAnalysisWrapper,
+                    "dxil-root-signature-analysis",
+                    "DXIL Root Signature Analysis", true, true)

diff  --git a/llvm/lib/Target/DirectX/DXILRootSignature.h b/llvm/lib/Target/DirectX/DXILRootSignature.h
new file mode 100644
index 0000000000000..8c25b2eb3fadf
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.h
@@ -0,0 +1,77 @@
+//===- DXILRootSignature.h - DXIL Root Signature helper objects -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains helper objects and APIs for working with DXIL
+///       Root Signatures.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_DIRECTX_DXILROOTSIGNATURE_H
+#define LLVM_LIB_TARGET_DIRECTX_DXILROOTSIGNATURE_H
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Analysis/DXILMetadataAnalysis.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/MC/DXContainerRootSignature.h"
+#include "llvm/Pass.h"
+#include <optional>
+
+namespace llvm {
+namespace dxil {
+
+/// Root element kinds recognized in root signature metadata; Error stands for
+/// any element name that is not recognized.
+enum class RootSignatureElementKind { Error = 0, RootFlags = 1 };
+
+/// New-pass-manager analysis that maps each function listed in the
+/// `dx.rootsignatures` named metadata to its parsed root signature.
+class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
+  friend AnalysisInfoMixin<RootSignatureAnalysis>;
+  static AnalysisKey Key;
+
+public:
+  RootSignatureAnalysis() = default;
+
+  using Result = SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>;
+
+  Result run(Module &M, ModuleAnalysisManager &AM);
+};
+
+/// Wrapper pass for the legacy pass manager.
+///
+/// This is required because the passes that will depend on this are codegen
+/// passes which run through the legacy pass manager.
+class RootSignatureAnalysisWrapper : public ModulePass {
+private:
+  /// Map produced by the most recent runOnModule.
+  SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc> FuncToRsMap;
+
+public:
+  static char ID;
+
+  RootSignatureAnalysisWrapper() : ModulePass(ID) {}
+
+  using iterator =
+      SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>::iterator;
+
+  /// Returns the entry holding F's root signature, or end() if none was
+  /// recorded for F.
+  iterator find(const Function *F) { return FuncToRsMap.find(F); }
+
+  iterator end() { return FuncToRsMap.end(); }
+
+  bool runOnModule(Module &M) override;
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+};
+
+/// Printer pass for RootSignatureAnalysis results.
+class RootSignatureAnalysisPrinter
+    : public PassInfoMixin<RootSignatureAnalysisPrinter> {
+  raw_ostream &OS;
+
+public:
+  explicit RootSignatureAnalysisPrinter(raw_ostream &OS) : OS(OS) {}
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
+};
+
+} // namespace dxil
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_DIRECTX_DXILROOTSIGNATURE_H

diff  --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index add23587de7d5..953ac3eb82098 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -77,6 +77,9 @@ void initializeDXILPrettyPrinterLegacyPass(PassRegistry &);
 /// Initializer for dxil::ShaderFlagsAnalysisWrapper pass.
 void initializeShaderFlagsAnalysisWrapperPass(PassRegistry &);
 
+/// Initializer for dxil::RootSignatureAnalysisWrapper pass.
+void initializeRootSignatureAnalysisWrapperPass(PassRegistry &);
+
 /// Initializer for DXContainerGlobals pass.
 void initializeDXContainerGlobalsPass(PassRegistry &);
 

diff  --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index 87591b104ce52..de5087ce1ae2f 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -18,6 +18,7 @@
 #endif
 MODULE_ANALYSIS("dx-shader-flags", dxil::ShaderFlagsAnalysis())
 MODULE_ANALYSIS("dxil-resource-md", DXILResourceMDAnalysis())
+MODULE_ANALYSIS("dxil-root-signature-analysis", dxil::RootSignatureAnalysis())
 #undef MODULE_ANALYSIS
 
 #ifndef MODULE_PASS
@@ -31,6 +32,7 @@ MODULE_PASS("dxil-pretty-printer", DXILPrettyPrinterPass(dbgs()))
 MODULE_PASS("dxil-translate-metadata", DXILTranslateMetadata())
 // TODO: rename to print<foo> after NPM switch
 MODULE_PASS("print-dx-shader-flags", dxil::ShaderFlagsAnalysisPrinter(dbgs()))
+MODULE_PASS("print<dxil-root-signature>", dxil::RootSignatureAnalysisPrinter(dbgs()))
 #undef MODULE_PASS
 
 #ifndef FUNCTION_PASS

diff  --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index ecb1bf775f857..a76c07f784177 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -19,6 +19,7 @@
 #include "DXILPrettyPrinter.h"
 #include "DXILResourceAccess.h"
 #include "DXILResourceAnalysis.h"
+#include "DXILRootSignature.h"
 #include "DXILShaderFlags.h"
 #include "DXILTranslateMetadata.h"
 #include "DXILWriter/DXILWriterPass.h"
@@ -61,6 +62,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   initializeDXILTranslateMetadataLegacyPass(*PR);
   initializeDXILResourceMDWrapperPass(*PR);
   initializeShaderFlagsAnalysisWrapperPass(*PR);
+  initializeRootSignatureAnalysisWrapperPass(*PR);
   initializeDXILFinalizeLinkageLegacyPass(*PR);
 }
 

diff  --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Error.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Error.ll
new file mode 100644
index 0000000000000..2a2188b1a13bb
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Error.ll
@@ -0,0 +1,18 @@
+; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
+
+target triple = "dxil-unknown-shadermodel6.0-compute"
+
+; CHECK: error: Invalid format for Root Signature Definition. Pairs of function, root signature expected.
+; CHECK-NOT: Root Signature Definitions
+
+
+define void @main() #0 {
+entry:
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+
+!dx.rootsignatures = !{!1} ; list of function/root signature pairs
+!1= !{ !"RootFlags" } ; malformed on purpose: one operand instead of a function, root signature pair

diff  --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll
new file mode 100644
index 0000000000000..4921472d253ad
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll
@@ -0,0 +1,20 @@
+; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
+
+target triple = "dxil-unknown-shadermodel6.0-compute"
+
+; CHECK: error: Invalid Root Signature Element: NOTRootFlags
+; CHECK-NOT: Root Signature Definitions
+
+
+define void @main() #0 {
+entry:
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+
+!dx.rootsignatures = !{!2} ; list of function/root signature pairs
+!2 = !{ ptr @main, !3 } ; function, root signature
+!3 = !{ !4 } ; list of root signature elements
+!4 = !{ !"NOTRootFlags", i32 1 } ; 1 = allow_input_assembler_input_layout

diff  --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Validation-Error.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Validation-Error.ll
new file mode 100644
index 0000000000000..fe93c9993c1c3
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Validation-Error.ll
@@ -0,0 +1,20 @@
+; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
+
+; CHECK: error: Invalid Root Signature flag value
+; CHECK-NOT: Root Signature Definitions
+
+target triple = "dxil-unknown-shadermodel6.0-compute"
+
+
+define void @main() #0 {
+entry:
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+
+!dx.rootsignatures = !{!2} ; list of function/root signature pairs
+!2 = !{ ptr @main, !3 } ; function, root signature
+!3 = !{ !4 } ; list of root signature elements
+!4 = !{ !"RootFlags", i32 2147487744 } ; 0x80001000: sets bits outside the valid 0xfff flag mask

diff  --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll
new file mode 100644
index 0000000000000..3f5bb166ad0e5
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll
@@ -0,0 +1,29 @@
+; RUN: opt %s -dxil-embed -dxil-globals -S -o - | FileCheck %s
+; RUN: llc %s --filetype=obj -o - | obj2yaml | FileCheck %s --check-prefix=DXC
+
+target triple = "dxil-unknown-shadermodel6.0-compute"
+
+; CHECK: @dx.rts0 = private constant [24 x i8]  c"{{.*}}", section "RTS0", align 4
+
+define void @main() #0 {
+entry:
+  ret void
+}
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+; Positive test: a single RootFlags element (value 1) is emitted as a 24-byte
+; RTS0 global, and the same part is read back by obj2yaml from the object file.
+!dx.rootsignatures = !{!2} ; list of function/root signature pairs
+!2 = !{ ptr @main, !3 } ; function, root signature
+!3 = !{ !4 } ; list of root signature elements
+!4 = !{ !"RootFlags", i32 1 } ; 1 = allow_input_assembler_input_layout
+; Expected obj2yaml dump of the RTS0 part: version 2, no parameters or static
+; samplers, and AllowInputAssemblerInputLayout reported as true.
+; DXC:  - Name:            RTS0
+; DXC-NEXT:    Size:            24
+; DXC-NEXT:    RootSignature:
+; DXC-NEXT:      Version:         2
+; DXC-NEXT:      NumParameters:   0
+; DXC-NEXT:      RootParametersOffset: 0
+; DXC-NEXT:      NumStaticSamplers: 0
+; DXC-NEXT:      StaticSamplersOffset: 0
+; DXC-NEXT:      AllowInputAssemblerInputLayout: true

diff  --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-MultipleEntryFunctions.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-MultipleEntryFunctions.ll
new file mode 100644
index 0000000000000..652f8092b7a69
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-MultipleEntryFunctions.ll
@@ -0,0 +1,41 @@
+; RUN: opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
+
+target triple = "dxil-unknown-shadermodel6.0-compute"
+; Two functions each carry their own root signature; the analysis must print a
+; separate definition, with that function's own flags, for each of them.
+define void @main() #0 {
+entry:
+  ret void
+}
+
+define void @anotherMain() #0 {
+entry:
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+!dx.rootsignatures = !{!2, !5} ; list of function/root signature pairs
+!2 = !{ ptr @main, !3 } ; function, root signature
+!3 = !{ !4 } ; list of root signature elements
+!4 = !{ !"RootFlags", i32 1 } ; 1 = allow_input_assembler_input_layout
+!5 = !{ ptr @anotherMain, !6 } ; function, root signature
+!6 = !{ !7 } ; list of root signature elements
+!7 = !{ !"RootFlags", i32 2 } ; 2 = deny_vertex_shader_root_access
+
+
+; CHECK-LABEL: Definition for 'main':
+; CHECK-NEXT:   Flags: 0x000001
+; CHECK-NEXT:   Version: 2
+; CHECK-NEXT:   NumParameters: 0
+; CHECK-NEXT:   RootParametersOffset: 0
+; CHECK-NEXT:   NumStaticSamplers: 0
+; CHECK-NEXT:   StaticSamplersOffset: 0
+
+; CHECK-LABEL: Definition for 'anotherMain':
+; CHECK-NEXT:   Flags: 0x000002
+; CHECK-NEXT:   Version: 2
+; CHECK-NEXT:   NumParameters: 0
+; CHECK-NEXT:   RootParametersOffset: 0
+; CHECK-NEXT:   NumStaticSamplers: 0
+; CHECK-NEXT:   StaticSamplersOffset: 0

diff  --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-NullFunction-Error.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-NullFunction-Error.ll
new file mode 100644
index 0000000000000..f5caa50124788
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-NullFunction-Error.ll
@@ -0,0 +1,21 @@
+; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
+
+; CHECK: error: Function associated with Root Signature definition is null
+; CHECK-NOT: Root Signature Definitions
+; Negative test: a function/root signature pair whose function operand is null is an error.
+target triple = "dxil-unknown-shadermodel6.0-compute"
+
+define void @main() #0 {
+entry:
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+!dx.rootsignatures = !{!2, !5} ; list of function/root signature pairs
+!2 = !{ ptr @main, !3 } ; function, root signature
+!3 = !{ !4 } ; list of root signature elements
+!4 = !{ !"RootFlags", i32 1 } ; 1 = allow_input_assembler_input_layout
+!5 = !{ null, !6 } ; null function (intentionally invalid), root signature
+!6 = !{ !7 } ; list of root signature elements
+!7 = !{ !"RootFlags", i32 2 } ; 2 = deny_vertex_shader_root_access

diff  --git a/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootElement-Error.ll b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootElement-Error.ll
new file mode 100644
index 0000000000000..89e23f6540c5f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ContainerData/RootSignature-RootElement-Error.ll
@@ -0,0 +1,19 @@
+; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
+
+target triple = "dxil-unknown-shadermodel6.0-compute"
+
+; CHECK: error: Missing Root Element Metadata Node.
+; CHECK-NOT: Root Signature Definitions
+; Negative test: the root signature operand must list element metadata nodes;
+; here it holds a bare string instead, which must be diagnosed.
+define void @main() #0 {
+entry:
+  ret void
+}
+
+attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
+
+
+!dx.rootsignatures = !{!2} ; list of function/root signature pairs
+!2 = !{ ptr @main, !3 } ; function, root signature
+!3 = !{ !"NOTRootElements" } ; intentionally malformed: a string where element nodes are expected

diff  --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
index 03b2150bbc1dc..afbf1ff72ec7e 100644
--- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll
+++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
@@ -33,6 +33,7 @@
 ; CHECK-ASM-NEXT: Print Module IR
 
 ; CHECK-OBJ-NEXT: DXIL Embedder
+; CHECK-OBJ-NEXT: DXIL Root Signature Analysis
 ; CHECK-OBJ-NEXT: DXContainer Global Emitter
 ; CHECK-OBJ-NEXT: FunctionPass Manager
 ; CHECK-OBJ-NEXT:   Lazy Machine Block Frequency Analysis

diff  --git a/llvm/tools/obj2yaml/dxcontainer2yaml.cpp b/llvm/tools/obj2yaml/dxcontainer2yaml.cpp
index 54a912d9438af..f3ef1b6a27bcf 100644
--- a/llvm/tools/obj2yaml/dxcontainer2yaml.cpp
+++ b/llvm/tools/obj2yaml/dxcontainer2yaml.cpp
@@ -156,7 +156,7 @@ dumpDXContainer(MemoryBufferRef Source) {
     case dxbc::PartType::RTS0:
       std::optional<DirectX::RootSignature> RS = Container.getRootSignature();
       if (RS.has_value())
-        NewPart.RootSignature = DXContainerYAML::RootSignatureDesc(*RS);
+        NewPart.RootSignature = DXContainerYAML::RootSignatureYamlDesc(*RS);
       break;
     }
   }

diff  --git a/llvm/unittests/Object/DXContainerTest.cpp b/llvm/unittests/Object/DXContainerTest.cpp
index e7b491103d2d0..5a73f32ab7c32 100644
--- a/llvm/unittests/Object/DXContainerTest.cpp
+++ b/llvm/unittests/Object/DXContainerTest.cpp
@@ -871,9 +871,8 @@ TEST(RootSignature, ParseRootFlags) {
         0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
     };
     EXPECT_THAT_EXPECTED(
-        DXContainer::create(getMemoryBuffer<68>(Buffer)),
-        FailedWithMessage("Stream Error: An unspecified error has occurred.  "
-                          "Invalid Root Signature Version"));
+        DXContainer::create(getMemoryBuffer<100>(Buffer)),
+        FailedWithMessage("unsupported root signature version read: 3"));
   }
   {
     // Flag has been set to an invalid value
@@ -886,8 +885,8 @@ TEST(RootSignature, ParseRootFlags) {
         0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0xFF,
     };
     EXPECT_THAT_EXPECTED(
-        DXContainer::create(getMemoryBuffer<68>(Buffer)),
-        FailedWithMessage("Stream Error: An unspecified error has occurred.  "
-                          "Invalid Root Signature flag"));
+        DXContainer::create(getMemoryBuffer<100>(Buffer)),
+        FailedWithMessage(
+            "unsupported root signature flag value read: 4278190081"));
   }
 }


        


More information about the llvm-commits mailing list