[llvm] [DXIL] Adding support to RootSignatureFlags generation to DXContainer (PR #122396)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 9 17:19:49 PST 2025


https://github.com/joaosaffran created https://github.com/llvm/llvm-project/pull/122396

This PR adds:
- A root signature 1.0 definition for `RootSignatureFlags`
- Root signature generation into the DX container
- Extraction of `RootSignatureFlags` from LLVM IR metadata
- Root signature generation in DXIL IR
- Validation of `RootSignatureFlags`
- Extraction of `RootSignatureFlags` from a DXContainer using `obj2yaml`

>From 155a5e3377d5a6102bf4a91604b796a4dd308456 Mon Sep 17 00:00:00 2001
From: joaosaffran <joao.saffran at microsoft.com>
Date: Mon, 6 Jan 2025 21:41:59 +0000
Subject: [PATCH] Extracting Root flags root signature element from llvm ir
 metadata

---
 .../llvm/Analysis/DXILMetadataAnalysis.h      |   2 +
 .../include/llvm/Analysis/DXILRootSignature.h |  88 ++++++++++++++
 .../BinaryFormat/DXContainerConstants.def     |   1 +
 llvm/include/llvm/Object/DXContainer.h        |   8 ++
 .../include/llvm/ObjectYAML/DXContainerYAML.h |  14 +++
 llvm/lib/Analysis/CMakeLists.txt              |   1 +
 llvm/lib/Analysis/DXILMetadataAnalysis.cpp    |  17 +++
 llvm/lib/Analysis/DXILRootSignature.cpp       | 110 ++++++++++++++++++
 llvm/lib/Object/DXContainer.cpp               |  15 +++
 llvm/lib/ObjectYAML/DXContainerEmitter.cpp    |   7 ++
 llvm/lib/ObjectYAML/DXContainerYAML.cpp       |  68 +++++++++++
 .../lib/Target/DirectX/DXContainerGlobals.cpp |  20 ++++
 llvm/tools/obj2yaml/dxcontainer2yaml.cpp      |  22 ++++
 13 files changed, 373 insertions(+)
 create mode 100644 llvm/include/llvm/Analysis/DXILRootSignature.h
 create mode 100644 llvm/lib/Analysis/DXILRootSignature.cpp

diff --git a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
index cb535ac14f1c61..7731c781a48353 100644
--- a/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
+++ b/llvm/include/llvm/Analysis/DXILMetadataAnalysis.h
@@ -10,6 +10,7 @@
 #define LLVM_ANALYSIS_DXILMETADATA_H
 
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/VersionTuple.h"
@@ -37,6 +38,9 @@ struct ModuleMetadataInfo {
   Triple::EnvironmentType ShaderProfile{Triple::UnknownEnvironment};
   VersionTuple ValidatorVersion{};
   SmallVector<EntryProperties> EntryPropertyVec{};
+  // Value-initialized so modules without root signature metadata never
+  // expose indeterminate bytes to consumers that serialize this field.
+  root_signature::VersionedRootSignatureDesc RootSignatureDesc{};
   void print(raw_ostream &OS) const;
 };
 
diff --git a/llvm/include/llvm/Analysis/DXILRootSignature.h b/llvm/include/llvm/Analysis/DXILRootSignature.h
new file mode 100644
index 00000000000000..cb3d6192f4404d
--- /dev/null
+++ b/llvm/include/llvm/Analysis/DXILRootSignature.h
@@ -0,0 +1,116 @@
+//===- DXILRootSignature.h - DXIL Root Signature helper objects -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains helper objects for working with DXIL Root
+/// Signatures.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_DXILROOTSIGNATURE_H
+#define LLVM_ANALYSIS_DXILROOTSIGNATURE_H
+
+#include "llvm/IR/Metadata.h"
+#include "llvm/Support/ScopedPrinter.h"
+#include "llvm/Support/SwapByteOrder.h"
+#include <cstdint>
+
+namespace llvm {
+namespace dxil {
+namespace root_signature {
+
+/// Kinds of element that can appear in a root signature's element list.
+enum class RootSignatureElementKind {
+  None = 0,
+  RootFlags = 1,
+  RootConstants = 2,
+  RootDescriptor = 3,
+  DescriptorTable = 4,
+  StaticSampler = 5
+};
+
+/// Root signature format versions. Version_1 is an alias of Version_1_0.
+enum class RootSignatureVersion {
+  Version_1 = 1,
+  Version_1_0 = 1,
+  Version_1_1 = 2,
+  Version_1_2 = 3
+};
+
+/// Root signature flag bits. Values may be OR'ed together; ValidFlags is the
+/// union of every bit defined here.
+enum RootSignatureFlags : uint32_t {
+  None = 0,
+  AllowInputAssemblerInputLayout = 0x1,
+  DenyVertexShaderRootAccess = 0x2,
+  DenyHullShaderRootAccess = 0x4,
+  DenyDomainShaderRootAccess = 0x8,
+  DenyGeometryShaderRootAccess = 0x10,
+  DenyPixelShaderRootAccess = 0x20,
+  AllowStreamOutput = 0x40,
+  LocalRootSignature = 0x80,
+  DenyAmplificationShaderRootAccess = 0x100,
+  DenyMeshShaderRootAccess = 0x200,
+  CBVSRVUAVHeapDirectlyIndexed = 0x400,
+  SamplerHeapDirectlyIndexed = 0x800,
+  AllowLowTierReservedHwCbLimit = 0x80000000,
+  ValidFlags = 0x80000fff
+};
+
+/// Version 1.0 root signature description; only the flags are modeled so far.
+struct DxilRootSignatureDesc1_0 {
+  RootSignatureFlags Flags;
+};
+
+/// A root signature description tagged with its format version; Version
+/// selects which union member is meaningful.
+struct VersionedRootSignatureDesc {
+  RootSignatureVersion Version;
+  union {
+    DxilRootSignatureDesc1_0 Desc_1_0;
+  };
+
+  // NOTE(review): declared but not defined anywhere in this patch.
+  bool isPopulated();
+
+  /// Byte-swap every field in place.
+  // NOTE(review): readStruct() in Object/DXContainer.cpp presumably calls
+  // this on big-endian hosts, so it needs a definition -- confirm.
+  void swapBytes() {
+    sys::swapByteOrder(Version);
+    switch (Version) {
+    case RootSignatureVersion::Version_1:
+      sys::swapByteOrder(Desc_1_0.Flags);
+      break;
+    case RootSignatureVersion::Version_1_1:
+    case RootSignatureVersion::Version_1_2:
+      // Not modeled yet; nothing else to swap.
+      break;
+    }
+  }
+};
+
+/// Extracts root signature descriptions from a named metadata node (the
+/// caller in DXILMetadataAnalysis passes dx.rootsignatures).
+class MetadataParser {
+public:
+  NamedMDNode *Root;
+  explicit MetadataParser(NamedMDNode *Root) : Root(Root) {}
+
+  /// Parse every element into \p Desc for \p Version. Returns true on error.
+  bool Parse(RootSignatureVersion Version, VersionedRootSignatureDesc *Desc);
+
+private:
+  bool ParseRootFlags(MDNode *RootFlagRoot, VersionedRootSignatureDesc *Desc);
+  bool ParseRootSignatureElement(MDNode *Element,
+                                 VersionedRootSignatureDesc *Desc);
+};
+} // namespace root_signature
+} // namespace dxil
+} // namespace llvm
+
+#endif // LLVM_ANALYSIS_DXILROOTSIGNATURE_H
diff --git a/llvm/include/llvm/BinaryFormat/DXContainerConstants.def b/llvm/include/llvm/BinaryFormat/DXContainerConstants.def
index 1aacbb2f65b27f..38b69228cd3975 100644
--- a/llvm/include/llvm/BinaryFormat/DXContainerConstants.def
+++ b/llvm/include/llvm/BinaryFormat/DXContainerConstants.def
@@ -4,6 +4,7 @@ CONTAINER_PART(DXIL)
 CONTAINER_PART(SFI0)
 CONTAINER_PART(HASH)
 CONTAINER_PART(PSV0)
+CONTAINER_PART(RTS0)
 CONTAINER_PART(ISG1)
 CONTAINER_PART(OSG1)
 CONTAINER_PART(PSG1)
diff --git a/llvm/include/llvm/Object/DXContainer.h b/llvm/include/llvm/Object/DXContainer.h
index 19c83ba6c6e85d..9a6aa8224eddf4 100644
--- a/llvm/include/llvm/Object/DXContainer.h
+++ b/llvm/include/llvm/Object/DXContainer.h
@@ -17,6 +17,9 @@
 
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
+// NOTE(review): this makes an Object header depend on an Analysis header; the
+// root signature structs likely belong under llvm/BinaryFormat instead.
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Support/MemoryBufferRef.h"
@@ -287,6 +290,7 @@ class DXContainer {
   std::optional<uint64_t> ShaderFeatureFlags;
   std::optional<dxbc::ShaderHash> Hash;
   std::optional<DirectX::PSVRuntimeInfo> PSVInfo;
+  std::optional<dxil::root_signature::VersionedRootSignatureDesc> RootSignature;
   DirectX::Signature InputSignature;
   DirectX::Signature OutputSignature;
   DirectX::Signature PatchConstantSignature;
@@ -296,6 +300,7 @@ class DXContainer {
   Error parseDXILHeader(StringRef Part);
   Error parseShaderFeatureFlags(StringRef Part);
   Error parseHash(StringRef Part);
+  Error parseRootSignature(StringRef Part);
   Error parsePSVInfo(StringRef Part);
   Error parseSignature(StringRef Part, DirectX::Signature &Array);
   friend class PartIterator;
@@ -382,6 +387,13 @@ class DXContainer {
 
   std::optional<dxbc::ShaderHash> getShaderHash() const { return Hash; }
 
+  /// The parsed RTS0 part, if the container has one. Returned by reference,
+  /// like getPSVInfo(), to avoid copying the optional on every call.
+  const std::optional<dxil::root_signature::VersionedRootSignatureDesc> &
+  getRootSignature() const {
+    return RootSignature;
+  }
+
   const std::optional<DirectX::PSVRuntimeInfo> &getPSVInfo() const {
     return PSVInfo;
   };
diff --git a/llvm/include/llvm/ObjectYAML/DXContainerYAML.h b/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
index 66ad057ab0e30f..e9da51f61c0a2b 100644
--- a/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
+++ b/llvm/include/llvm/ObjectYAML/DXContainerYAML.h
@@ -16,6 +16,7 @@
 #define LLVM_OBJECTYAML_DXCONTAINERYAML_H
 
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/ObjectYAML/YAML.h"
 #include "llvm/Support/YAMLTraits.h"
@@ -149,6 +150,18 @@ struct Signature {
   llvm::SmallVector<SignatureParameter> Parameters;
 };
 
+/// YAML model of an RTS0 (root signature) part.
+struct RootSignature {
+  RootSignature() = default;
+
+  // Default to an empty 1.0 root signature so a default-constructed object
+  // never carries indeterminate values into the writer or YAML output.
+  dxil::root_signature::RootSignatureVersion Version =
+      dxil::root_signature::RootSignatureVersion::Version_1_0;
+  dxil::root_signature::RootSignatureFlags Flags =
+      dxil::root_signature::RootSignatureFlags::None;
+};
+
 struct Part {
   Part() = default;
   Part(std::string N, uint32_t S) : Name(N), Size(S) {}
@@ -159,6 +172,7 @@ struct Part {
   std::optional<ShaderHash> Hash;
   std::optional<PSVInfo> Info;
   std::optional<DXContainerYAML::Signature> Signature;
+  std::optional<DXContainerYAML::RootSignature> RootSignature;
 };
 
 struct Object {
@@ -241,6 +255,11 @@ template <> struct MappingTraits<DXContainerYAML::Signature> {
   static void mapping(IO &IO, llvm::DXContainerYAML::Signature &El);
 };
 
+template <> struct MappingTraits<DXContainerYAML::RootSignature> {
+  static void mapping(IO &IO,
+                      llvm::DXContainerYAML::RootSignature &RootSignature);
+};
+
 } // namespace yaml
 
 } // namespace llvm
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
index 0db5b80f336cb5..8875ddd34fe56c 100644
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -62,6 +62,7 @@ add_llvm_component_library(LLVMAnalysis
   DominanceFrontier.cpp
   DXILResource.cpp
   DXILMetadataAnalysis.cpp
+  DXILRootSignature.cpp
   FunctionPropertiesAnalysis.cpp
   GlobalsModRef.cpp
   GuardUtils.cpp
diff --git a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
index a7f666a3f8b48f..3bd60bfe203f49 100644
--- a/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
+++ b/llvm/lib/Analysis/DXILMetadataAnalysis.cpp
@@ -10,12 +10,15 @@
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/InitializePasses.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
+#include <memory>
 
 #define DEBUG_TYPE "dxil-metadata-analysis"
 
@@ -28,6 +31,7 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
   MMDAI.DXILVersion = TT.getDXILVersion();
   MMDAI.ShaderModelVersion = TT.getOSVersion();
   MMDAI.ShaderProfile = TT.getEnvironment();
+
   NamedMDNode *ValidatorVerNode = M.getNamedMetadata("dx.valver");
   if (ValidatorVerNode) {
     auto *ValVerMD = cast<MDNode>(ValidatorVerNode->getOperand(0));
@@ -37,6 +41,21 @@ static ModuleMetadataInfo collectMetadataInfo(Module &M) {
         VersionTuple(MajorMD->getZExtValue(), MinorMD->getZExtValue());
   }
 
+  NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures");
+  if (RootSignatureNode) {
+    auto RootSignatureParser =
+        root_signature::MetadataParser(RootSignatureNode);
+
+    root_signature::VersionedRootSignatureDesc Desc{};
+
+    // Parse() returns true on error. Only publish a fully parsed description
+    // so consumers never observe a partially populated one.
+    // TODO(review): report a diagnostic instead of dropping it silently.
+    if (!RootSignatureParser.Parse(
+            root_signature::RootSignatureVersion::Version_1, &Desc))
+      MMDAI.RootSignatureDesc = Desc;
+  }
+
   // For all HLSL Shader functions
   for (auto &F : M.functions()) {
     if (!F.hasFnAttribute("hlsl.shader"))
diff --git a/llvm/lib/Analysis/DXILRootSignature.cpp b/llvm/lib/Analysis/DXILRootSignature.cpp
new file mode 100644
index 00000000000000..fce97eb27cf8f8
--- /dev/null
+++ b/llvm/lib/Analysis/DXILRootSignature.cpp
@@ -0,0 +1,130 @@
+//===- DXILRootSignature.cpp - DXIL Root Signature helper objects ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains the parsing logic to extract root signature data
+///       from LLVM IR metadata.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/DXILRootSignature.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <cassert>
+
+namespace llvm {
+namespace dxil {
+
+// The parse functions below return true on error and false on success.
+
+bool root_signature::MetadataParser::Parse(RootSignatureVersion Version,
+                                           VersionedRootSignatureDesc *Desc) {
+  Desc->Version = Version;
+  // Start from "no flags" so a root signature without a RootFlags element
+  // does not leave the field indeterminate.
+  if (Version == RootSignatureVersion::Version_1_0)
+    Desc->Desc_1_0.Flags = RootSignatureFlags::None;
+
+  bool HasError = false;
+  for (MDNode *Node : Root->operands()) {
+    // Operand 0 is presumably the associated function (unused for now);
+    // operand 1 is the list of root signature elements.
+    if (!Node || Node->getNumOperands() < 2) {
+      HasError = true;
+      continue;
+    }
+
+    auto *Elements = dyn_cast_or_null<MDNode>(Node->getOperand(1).get());
+    if (!Elements) {
+      HasError = true;
+      continue;
+    }
+
+    for (const MDOperand &ElementOp : Elements->operands()) {
+      auto *Element = dyn_cast_or_null<MDNode>(ElementOp.get());
+      // Keep parsing after a failure: the former `HasError || Parse(...)`
+      // short-circuited and silently skipped every later element.
+      HasError |= !Element || ParseRootSignatureElement(Element, Desc);
+    }
+  }
+  return HasError;
+}
+
+bool root_signature::MetadataParser::ParseRootFlags(
+    MDNode *RootFlagNode, VersionedRootSignatureDesc *Desc) {
+  // Expected shape: !{!"RootFlags", <integer flags>}.
+  if (RootFlagNode->getNumOperands() != 2)
+    return true;
+
+  auto *Flag = mdconst::dyn_extract<ConstantInt>(RootFlagNode->getOperand(1));
+  if (!Flag)
+    return true;
+
+  // Validate the full value before narrowing to 32 bits, so high bits cannot
+  // be truncated away and slip past the ValidFlags check.
+  uint64_t Value = Flag->getZExtValue();
+  if ((Value & ~static_cast<uint64_t>(RootSignatureFlags::ValidFlags)) != 0)
+    return true;
+
+  switch (Desc->Version) {
+  case RootSignatureVersion::Version_1:
+    Desc->Desc_1_0.Flags = static_cast<RootSignatureFlags>(Value);
+    return false;
+  case RootSignatureVersion::Version_1_1:
+  case RootSignatureVersion::Version_1_2:
+    // Not implemented yet. The metadata comes from user IR, so report an
+    // error instead of reaching llvm_unreachable.
+    return true;
+  }
+  return true;
+}
+
+bool root_signature::MetadataParser::ParseRootSignatureElement(
+    MDNode *Element, VersionedRootSignatureDesc *Desc) {
+  // The first operand names the element kind, e.g. !"RootFlags".
+  if (Element->getNumOperands() == 0)
+    return true;
+  auto *ElementText = dyn_cast_or_null<MDString>(Element->getOperand(0).get());
+  if (!ElementText)
+    return true;
+
+  // NOTE(review): "Sampler" is mapped to RootDescriptor; confirm it should
+  // not be a descriptor-table clause instead.
+  RootSignatureElementKind ElementKind =
+      StringSwitch<RootSignatureElementKind>(ElementText->getString())
+          .Case("RootFlags", RootSignatureElementKind::RootFlags)
+          .Case("RootConstants", RootSignatureElementKind::RootConstants)
+          .Case("RootCBV", RootSignatureElementKind::RootDescriptor)
+          .Case("RootSRV", RootSignatureElementKind::RootDescriptor)
+          .Case("RootUAV", RootSignatureElementKind::RootDescriptor)
+          .Case("Sampler", RootSignatureElementKind::RootDescriptor)
+          .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
+          .Case("StaticSampler", RootSignatureElementKind::StaticSampler)
+          .Default(RootSignatureElementKind::None);
+
+  switch (ElementKind) {
+  case RootSignatureElementKind::RootFlags:
+    return ParseRootFlags(Element, Desc);
+
+  case RootSignatureElementKind::RootConstants:
+  case RootSignatureElementKind::RootDescriptor:
+  case RootSignatureElementKind::DescriptorTable:
+  case RootSignatureElementKind::StaticSampler:
+    // Not implemented yet. The metadata comes from user IR, so report an
+    // error instead of reaching llvm_unreachable.
+    return true;
+
+  case RootSignatureElementKind::None:
+    // Unrecognized element name.
+    return true;
+  }
+  return true;
+}
+} // namespace dxil
+} // namespace llvm
diff --git a/llvm/lib/Object/DXContainer.cpp b/llvm/lib/Object/DXContainer.cpp
index 3b1a6203a1f8fc..f50f68df88ec2a 100644
--- a/llvm/lib/Object/DXContainer.cpp
+++ b/llvm/lib/Object/DXContainer.cpp
@@ -7,9 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Object/DXContainer.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Object/Error.h"
 #include "llvm/Support/Alignment.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
 
 using namespace llvm;
@@ -92,6 +94,27 @@ Error DXContainer::parseHash(StringRef Part) {
   return Error::success();
 }
 
+Error DXContainer::parseRootSignature(StringRef Part) {
+  // Mirror parsePSVInfo: a container may hold at most one RTS0 part.
+  if (RootSignature)
+    return parseFailed("More than one RTS0 part is present in the file");
+  dxil::root_signature::VersionedRootSignatureDesc Desc;
+  if (Error Err = readStruct(Part, Part.begin(), Desc))
+    return Err;
+  // The part comes from an input file; reject versions this code does not
+  // know so later consumers never switch over an invalid enum value.
+  switch (Desc.Version) {
+  case dxil::root_signature::RootSignatureVersion::Version_1_0:
+  case dxil::root_signature::RootSignatureVersion::Version_1_1:
+  case dxil::root_signature::RootSignatureVersion::Version_1_2:
+    break;
+  default:
+    return parseFailed("Unsupported root signature version");
+  }
+  RootSignature = Desc;
+  return Error::success();
+}
+
 Error DXContainer::parsePSVInfo(StringRef Part) {
   if (PSVInfo)
     return parseFailed("More than one PSV0 part is present in the file");
@@ -192,6 +215,10 @@ Error DXContainer::parsePartOffsets() {
         return Err;
       break;
     case dxbc::PartType::Unknown:
+      break;
+    case dxbc::PartType::RTS0:
+      if (Error Err = parseRootSignature(PartData))
+        return Err;
       break;
     }
   }
diff --git a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
index 175f1a12f93145..7f576e8731a128 100644
--- a/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerEmitter.cpp
@@ -11,6 +11,7 @@
 ///
 //===----------------------------------------------------------------------===//
 
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/MC/DXContainerPSVInfo.h"
 #include "llvm/ObjectYAML/ObjectYAML.h"
@@ -261,6 +262,21 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
     }
     case dxbc::PartType::Unknown:
       break; // Skip any handling for unrecognized parts.
+    case dxbc::PartType::RTS0: {
+      if (!P.RootSignature)
+        continue;
+      // Serialize the described root signature itself: writing
+      // &P.RootSignature would emit the bytes of the std::optional wrapper
+      // around the YAML struct, with a size that does not match it. Only the
+      // 1.0 layout is modeled, so the flags always go into Desc_1_0.
+      dxil::root_signature::VersionedRootSignatureDesc Desc{};
+      Desc.Version = P.RootSignature->Version;
+      Desc.Desc_1_0.Flags = P.RootSignature->Flags;
+      // NOTE(review): written in host byte order; if DXContainer parts must be
+      // little endian, big-endian hosts need a byte swap here.
+      OS.write(reinterpret_cast<const char *>(&Desc), sizeof(Desc));
+      break;
+    }
     }
     uint64_t BytesWritten = OS.tell() - DataStart;
     RollingOffset += BytesWritten;
diff --git a/llvm/lib/ObjectYAML/DXContainerYAML.cpp b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
index 5dee1221b27c01..eab3fcc5936f85 100644
--- a/llvm/lib/ObjectYAML/DXContainerYAML.cpp
+++ b/llvm/lib/ObjectYAML/DXContainerYAML.cpp
@@ -13,6 +13,7 @@
 
 #include "llvm/ObjectYAML/DXContainerYAML.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/Support/ScopedPrinter.h"
 
@@ -188,6 +189,65 @@ void MappingTraits<DXContainerYAML::Signature>::mapping(
   IO.mapRequired("Parameters", S.Parameters);
 }
 
+// These traits must be specialized before their first use in
+// MappingTraits<DXContainerYAML::RootSignature>::mapping below; specializing
+// after that use is ill-formed.
+template <>
+struct ScalarEnumerationTraits<dxil::root_signature::RootSignatureVersion> {
+  static void enumeration(IO &IO,
+                          dxil::root_signature::RootSignatureVersion &Val) {
+    // Version_1 and Version_1_0 share one value, so "1.0" is listed once.
+    IO.enumCase(Val, "1.0",
+                dxil::root_signature::RootSignatureVersion::Version_1_0);
+    IO.enumCase(Val, "1.1",
+                dxil::root_signature::RootSignatureVersion::Version_1_1);
+    IO.enumCase(Val, "1.2",
+                dxil::root_signature::RootSignatureVersion::Version_1_2);
+  }
+};
+
+// RootSignatureFlags is a bitmask: None (0) and combinations of flags are
+// valid values, which a single enumeration scalar cannot represent (YAML
+// output of such a value would fail), so map it as a sequence of flag names.
+template <>
+struct ScalarBitSetTraits<dxil::root_signature::RootSignatureFlags> {
+  static void bitset(IO &IO, dxil::root_signature::RootSignatureFlags &Val) {
+    namespace rs = dxil::root_signature;
+    // Accumulate in a uint32_t: `Val | Flag` has type uint32_t and cannot be
+    // assigned back to the unscoped enum directly.
+    uint32_t Bits = Val;
+    auto FlagCase = [&](const char *Name, rs::RootSignatureFlags Flag) {
+      if (IO.bitSetMatch(Name, IO.outputting() && (Bits & Flag) == Flag))
+        Bits |= Flag;
+    };
+    FlagCase("AllowInputAssemblerInputLayout",
+             rs::AllowInputAssemblerInputLayout);
+    FlagCase("DenyVertexShaderRootAccess", rs::DenyVertexShaderRootAccess);
+    FlagCase("DenyHullShaderRootAccess", rs::DenyHullShaderRootAccess);
+    FlagCase("DenyDomainShaderRootAccess", rs::DenyDomainShaderRootAccess);
+    FlagCase("DenyGeometryShaderRootAccess",
+             rs::DenyGeometryShaderRootAccess);
+    FlagCase("DenyPixelShaderRootAccess", rs::DenyPixelShaderRootAccess);
+    FlagCase("AllowStreamOutput", rs::AllowStreamOutput);
+    FlagCase("LocalRootSignature", rs::LocalRootSignature);
+    FlagCase("DenyAmplificationShaderRootAccess",
+             rs::DenyAmplificationShaderRootAccess);
+    FlagCase("DenyMeshShaderRootAccess", rs::DenyMeshShaderRootAccess);
+    FlagCase("CBVSRVUAVHeapDirectlyIndexed",
+             rs::CBVSRVUAVHeapDirectlyIndexed);
+    FlagCase("SamplerHeapDirectlyIndexed", rs::SamplerHeapDirectlyIndexed);
+    FlagCase("AllowLowTierReservedHwCbLimit",
+             rs::AllowLowTierReservedHwCbLimit);
+    Val = static_cast<rs::RootSignatureFlags>(Bits);
+  }
+};
+
+void MappingTraits<DXContainerYAML::RootSignature>::mapping(
+    IO &IO, DXContainerYAML::RootSignature &S) {
+  IO.mapRequired("Version", S.Version);
+  IO.mapRequired("Flags", S.Flags);
+}
+
 void MappingTraits<DXContainerYAML::Part>::mapping(IO &IO,
                                                    DXContainerYAML::Part &P) {
   IO.mapRequired("Name", P.Name);
@@ -197,6 +257,7 @@ void MappingTraits<DXContainerYAML::Part>::mapping(IO &IO,
   IO.mapOptional("Hash", P.Hash);
   IO.mapOptional("PSVInfo", P.Info);
   IO.mapOptional("Signature", P.Signature);
+  IO.mapOptional("RootSignature", P.RootSignature);
 }
 
 void MappingTraits<DXContainerYAML::Object>::mapping(
diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index 7a0bd6a7c88692..886b9f6be41056 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -17,6 +17,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Analysis/DXILMetadataAnalysis.h"
 #include "llvm/Analysis/DXILResource.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/BinaryFormat/DXContainer.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/IR/Constants.h"
@@ -41,6 +42,7 @@ class DXContainerGlobals : public llvm::ModulePass {
   GlobalVariable *buildSignature(Module &M, Signature &Sig, StringRef Name,
                                  StringRef SectionName);
   void addSignature(Module &M, SmallVector<GlobalValue *> &Globals);
+  void addRootSignature(Module &M, SmallVector<GlobalValue *> &Globals);
   void addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV);
   void addPipelineStateValidationInfo(Module &M,
                                       SmallVector<GlobalValue *> &Globals);
@@ -73,6 +75,7 @@ bool DXContainerGlobals::runOnModule(Module &M) {
   Globals.push_back(getFeatureFlags(M));
   Globals.push_back(computeShaderHash(M));
   addSignature(M, Globals);
+  addRootSignature(M, Globals);
   addPipelineStateValidationInfo(M, Globals);
   appendToCompilerUsed(M, Globals);
   return true;
@@ -144,6 +147,29 @@ void DXContainerGlobals::addSignature(Module &M,
   Globals.emplace_back(buildSignature(M, OutputSig, "dx.osg1", "OSG1"));
 }
 
+void DXContainerGlobals::addRootSignature(Module &M,
+                                          SmallVector<GlobalValue *> &Globals) {
+  // Only emit an RTS0 part when the module carries root signature metadata;
+  // otherwise there is no root signature to serialize.
+  if (!M.getNamedMetadata("dx.rootsignatures"))
+    return;
+
+  root_signature::VersionedRootSignatureDesc Desc =
+      getAnalysis<DXILMetadataAnalysisWrapperPass>()
+          .getModuleMetadata()
+          .RootSignatureDesc;
+
+  // NOTE(review): the descriptor is copied in host byte order; if DXContainer
+  // parts must be little endian, big-endian hosts need a byte swap here.
+  SmallString<256> Data;
+  raw_svector_ostream OS(Data);
+  OS.write(reinterpret_cast<const char *>(&Desc), sizeof(Desc));
+  Constant *RootSigConstant =
+      ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
+  Globals.emplace_back(
+      buildContainerGlobal(M, RootSigConstant, "dx.rts0", "RTS0"));
+}
+
 void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) {
   const DXILBindingMap &DBM =
       getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap();
diff --git a/llvm/tools/obj2yaml/dxcontainer2yaml.cpp b/llvm/tools/obj2yaml/dxcontainer2yaml.cpp
index 06966b1883586c..9fa2612886b41d 100644
--- a/llvm/tools/obj2yaml/dxcontainer2yaml.cpp
+++ b/llvm/tools/obj2yaml/dxcontainer2yaml.cpp
@@ -7,9 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "obj2yaml.h"
+#include "llvm/Analysis/DXILRootSignature.h"
 #include "llvm/Object/DXContainer.h"
 #include "llvm/ObjectYAML/DXContainerYAML.h"
 #include "llvm/Support/Error.h"
+#include "llvm/Support/ErrorHandling.h"
 
 #include <algorithm>
 
@@ -26,6 +28,27 @@ static DXContainerYAML::Signature dumpSignature(const DirectX::Signature &Sig) {
   return YAML;
 }
 
+static DXContainerYAML::RootSignature dumpRootSignature(
+    const dxil::root_signature::VersionedRootSignatureDesc &Desc) {
+  DXContainerYAML::RootSignature YAML;
+  YAML.Version = Desc.Version;
+  // Start from "no flags" so unsupported versions never dump garbage.
+  YAML.Flags = dxil::root_signature::RootSignatureFlags::None;
+
+  switch (Desc.Version) {
+  case dxil::root_signature::RootSignatureVersion::Version_1:
+    YAML.Flags = Desc.Desc_1_0.Flags;
+    break;
+  case dxil::root_signature::RootSignatureVersion::Version_1_1:
+  case dxil::root_signature::RootSignatureVersion::Version_1_2:
+    // The container comes from an arbitrary input file, so an unsupported
+    // version must not reach llvm_unreachable; its flags are not dumped yet.
+    // TODO: dump the 1.1/1.2 layouts once they are modeled.
+    break;
+  }
+  return YAML;
+}
+
 static Expected<DXContainerYAML::Object *>
 dumpDXContainer(MemoryBufferRef Source) {
   assert(file_magic::dxcontainer_object == identify_magic(Source.getBuffer()));
@@ -153,6 +176,12 @@ dumpDXContainer(MemoryBufferRef Source) {
       break;
     case dxbc::PartType::Unknown:
       break;
+    case dxbc::PartType::RTS0: {
+      const auto &RootSig = Container.getRootSignature();
+      if (RootSig)
+        NewPart.RootSignature = dumpRootSignature(*RootSig);
+      break;
+    }
     }
   }
 



More information about the llvm-commits mailing list