[llvm] [NFC][DirectX] Infrastructure to collect shader flags for each function (PR #112967)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 18 13:01:10 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-directx
Author: S. Bharadwaj Yadavalli (bharadwajy)
<details>
<summary>Changes</summary>
Currently, ShaderFlagsAnalysis pass represents various module-level properties as well as function-level properties of a DXIL Module using a single mask. However, separate flags to represent module-level properties and function-level properties are needed for accurate computation of shader flags mask, such as for entry function metadata creation.
This change introduces a structure that allows separate representation of
(a) shader flag mask to represent module properties
(b) a map from each function to the shader flag mask that represents that function's properties
instead of a single shader flag mask that represents module properties and properties of all functions. The result type of the ShaderFlagsAnalysis pass is changed to the newly defined structure type instead of a single shader flags mask.
This separation allows accurate computation of shader flags of an entry function for use during its metadata generation (DXILTranslateMetadata pass) and its feature flags in the DX container globals construction (DXContainerGlobals pass) based on the shader flags mask of functions called in the entry function. However, note that the change to implement such callee-based shader flags mask computation is planned in a follow-on PR. Consequently, this PR changes shader flag mask computation in DXILTranslateMetadata and DXContainerGlobals passes to simply be a union of module flags and shader flags of all functions, thereby retaining the existing effect of using a single shader flag mask.
---
Full diff: https://github.com/llvm/llvm-project/pull/112967.diff
4 Files Affected:
- (modified) llvm/lib/Target/DirectX/DXContainerGlobals.cpp (+10-5)
- (modified) llvm/lib/Target/DirectX/DXILShaderFlags.cpp (+32-14)
- (modified) llvm/lib/Target/DirectX/DXILShaderFlags.h (+17-9)
- (modified) llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp (+28-18)
``````````diff
diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index 2c11373504e8c7..c7202cc04c26dc 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -78,13 +78,18 @@ bool DXContainerGlobals::runOnModule(Module &M) {
}
GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) {
- const uint64_t FeatureFlags =
- static_cast<uint64_t>(getAnalysis<ShaderFlagsAnalysisWrapper>()
- .getShaderFlags()
- .getFeatureFlags());
+ const DXILModuleShaderFlagsInfo &MSFI =
+ getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
+ // TODO: Feature flags mask is obtained as a collection of feature flags
+ // of the shader flags of all functions in the module. Need to verify
+ // and modify the computation of feature flags to be used.
+ uint64_t ConsolidatedFeatureFlags = 0;
+ for (const auto &FuncFlags : MSFI.FuncShaderFlagsMap) {
+ ConsolidatedFeatureFlags |= FuncFlags.second.getFeatureFlags();
+ }
Constant *FeatureFlagsConstant =
- ConstantInt::get(M.getContext(), APInt(64, FeatureFlags));
+ ConstantInt::get(M.getContext(), APInt(64, ConsolidatedFeatureFlags));
return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0");
}
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 9fa137b4c025e1..8c590862008862 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -20,33 +20,41 @@
using namespace llvm;
using namespace llvm::dxil;
-static void updateFlags(ComputedShaderFlags &Flags, const Instruction &I) {
+static void updateFlags(DXILModuleShaderFlagsInfo &MSFI, const Instruction &I) {
+ ComputedShaderFlags &FSF = MSFI.FuncShaderFlagsMap[I.getFunction()];
Type *Ty = I.getType();
if (Ty->isDoubleTy()) {
- Flags.Doubles = true;
+ FSF.Doubles = true;
switch (I.getOpcode()) {
case Instruction::FDiv:
case Instruction::UIToFP:
case Instruction::SIToFP:
case Instruction::FPToUI:
case Instruction::FPToSI:
- Flags.DX11_1_DoubleExtensions = true;
+ FSF.DX11_1_DoubleExtensions = true;
break;
}
}
}
-ComputedShaderFlags ComputedShaderFlags::computeFlags(Module &M) {
- ComputedShaderFlags Flags;
- for (const auto &F : M)
+static DXILModuleShaderFlagsInfo computeFlags(Module &M) {
+ DXILModuleShaderFlagsInfo MSFI;
+ for (const auto &F : M) {
+ if (F.isDeclaration())
+ continue;
+ if (!MSFI.FuncShaderFlagsMap.contains(&F)) {
+ ComputedShaderFlags CSF{};
+ MSFI.FuncShaderFlagsMap[&F] = CSF;
+ }
for (const auto &BB : F)
for (const auto &I : BB)
- updateFlags(Flags, I);
- return Flags;
+ updateFlags(MSFI, I);
+ }
+ return MSFI;
}
void ComputedShaderFlags::print(raw_ostream &OS) const {
- uint64_t FlagVal = (uint64_t) * this;
+ uint64_t FlagVal = (uint64_t)*this;
OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal);
if (FlagVal == 0)
return;
@@ -65,15 +73,25 @@ void ComputedShaderFlags::print(raw_ostream &OS) const {
AnalysisKey ShaderFlagsAnalysis::Key;
-ComputedShaderFlags ShaderFlagsAnalysis::run(Module &M,
- ModuleAnalysisManager &AM) {
- return ComputedShaderFlags::computeFlags(M);
+DXILModuleShaderFlagsInfo ShaderFlagsAnalysis::run(Module &M,
+ ModuleAnalysisManager &AM) {
+ return computeFlags(M);
+}
+
+bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
+ MSFI = computeFlags(M);
+ return false;
}
PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
ModuleAnalysisManager &AM) {
- ComputedShaderFlags Flags = AM.getResult<ShaderFlagsAnalysis>(M);
- Flags.print(OS);
+ DXILModuleShaderFlagsInfo Flags = AM.getResult<ShaderFlagsAnalysis>(M);
+ OS << "; Shader Flags mask for Module:\n";
+ Flags.ModuleFlags.print(OS);
+ for (auto SF : Flags.FuncShaderFlagsMap) {
+ OS << "; Shader Flags mash for Function: " << SF.first->getName() << "\n";
+ SF.second.print(OS);
+ }
return PreservedAnalyses::all();
}
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.h b/llvm/lib/Target/DirectX/DXILShaderFlags.h
index 1df7d27de13d3c..6f81ff74384d0c 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.h
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.h
@@ -14,6 +14,8 @@
#ifndef LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
#define LLVM_TARGET_DIRECTX_DXILSHADERFLAGS_H
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/Compiler.h"
@@ -60,11 +62,20 @@ struct ComputedShaderFlags {
return FeatureFlags;
}
- static ComputedShaderFlags computeFlags(Module &M);
void print(raw_ostream &OS = dbgs()) const;
LLVM_DUMP_METHOD void dump() const { print(); }
};
+using FunctionShaderFlagsMap =
+ SmallDenseMap<Function const *, ComputedShaderFlags>;
+struct DXILModuleShaderFlagsInfo {
+ // Shader Flag mask representing module-level properties
+ ComputedShaderFlags ModuleFlags;
+ // Map representing shader flag mask representing properties of each of the
+ // functions in the module
+ FunctionShaderFlagsMap FuncShaderFlagsMap;
+};
+
class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
friend AnalysisInfoMixin<ShaderFlagsAnalysis>;
static AnalysisKey Key;
@@ -72,9 +83,9 @@ class ShaderFlagsAnalysis : public AnalysisInfoMixin<ShaderFlagsAnalysis> {
public:
ShaderFlagsAnalysis() = default;
- using Result = ComputedShaderFlags;
+ using Result = DXILModuleShaderFlagsInfo;
- ComputedShaderFlags run(Module &M, ModuleAnalysisManager &AM);
+ DXILModuleShaderFlagsInfo run(Module &M, ModuleAnalysisManager &AM);
};
/// Printer pass for ShaderFlagsAnalysis results.
@@ -92,19 +103,16 @@ class ShaderFlagsAnalysisPrinter
/// This is required because the passes that will depend on this are codegen
/// passes which run through the legacy pass manager.
class ShaderFlagsAnalysisWrapper : public ModulePass {
- ComputedShaderFlags Flags;
+ DXILModuleShaderFlagsInfo MSFI;
public:
static char ID;
ShaderFlagsAnalysisWrapper() : ModulePass(ID) {}
- const ComputedShaderFlags &getShaderFlags() { return Flags; }
+ const DXILModuleShaderFlagsInfo &getShaderFlags() { return MSFI; }
- bool runOnModule(Module &M) override {
- Flags = ComputedShaderFlags::computeFlags(M);
- return false;
- }
+ bool runOnModule(Module &M) override;
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesAll();
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index be370e10df6943..2da4fe83a066c2 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -286,11 +286,6 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
MDTuple *Properties = nullptr;
if (ShaderFlags != 0) {
SmallVector<Metadata *> MDVals;
- // FIXME: ShaderFlagsAnalysis pass needs to collect and provide
- // ShaderFlags for each entry function. Currently, ShaderFlags value
- // provided by ShaderFlagsAnalysis pass is created by walking *all* the
- // function instructions of the module. Is it is correct to use this value
- // for metadata of the empty library entry?
MDVals.append(
getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx));
Properties = MDNode::get(Ctx, MDVals);
@@ -302,7 +297,7 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
static void translateMetadata(Module &M, const DXILResourceMap &DRM,
const Resources &MDResources,
- const ComputedShaderFlags &ShaderFlags,
+ const DXILModuleShaderFlagsInfo &ShaderFlags,
const ModuleMetadataInfo &MMDI) {
LLVMContext &Ctx = M.getContext();
IRBuilder<> IRB(Ctx);
@@ -318,22 +313,37 @@ static void translateMetadata(Module &M, const DXILResourceMap &DRM,
// See https://github.com/llvm/llvm-project/issues/57928
MDTuple *Signatures = nullptr;
- if (MMDI.ShaderProfile == Triple::EnvironmentType::Library)
+ if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) {
+ // Create a consolidated shader flag mask of all functions in the library
+ // to be used as shader flags mask value associated with top-level library
+ // entry metadata.
+ uint64_t ConsolidatedMask = ShaderFlags.ModuleFlags;
+ for (const auto &FunFlags : ShaderFlags.FuncShaderFlagsMap) {
+ ConsolidatedMask |= FunFlags.second;
+ }
EntryFnMDNodes.emplace_back(
- emitTopLevelLibraryNode(M, ResourceMD, ShaderFlags));
- else if (MMDI.EntryPropertyVec.size() > 1) {
+ emitTopLevelLibraryNode(M, ResourceMD, ConsolidatedMask));
+ } else if (MMDI.EntryPropertyVec.size() > 1) {
M.getContext().diagnose(DiagnosticInfoTranslateMD(
M, "Non-library shader: One and only one entry expected"));
}
for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
- // FIXME: ShaderFlagsAnalysis pass needs to collect and provide
- // ShaderFlags for each entry function. For now, assume shader flags value
- // of entry functions being compiled for lib_* shader profile viz.,
- // EntryPro.Entry is 0.
- uint64_t EntryShaderFlags =
- (MMDI.ShaderProfile == Triple::EnvironmentType::Library) ? 0
- : ShaderFlags;
+ auto FSFIt = ShaderFlags.FuncShaderFlagsMap.find(EntryProp.Entry);
+ if (FSFIt == ShaderFlags.FuncShaderFlagsMap.end()) {
+ M.getContext().diagnose(DiagnosticInfoTranslateMD(
+ M, "Shader Flags of Function '" + Twine(EntryProp.Entry->getName()) +
+ "' not found"));
+ }
+ // If ShaderProfile is Library, mask is already consolidated in the
+ // top-level library node. Hence it is not emitted.
+ uint64_t EntryShaderFlags = 0;
+ if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
+ // TODO: Create a consolidated shader flag mask of all the entry
+ // functions and its callees. The following is correct only if
+ // (*FSIt).first has no call instructions.
+ EntryShaderFlags = (*FSFIt).second | ShaderFlags.ModuleFlags;
+ }
if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
M.getContext().diagnose(DiagnosticInfoTranslateMD(
@@ -361,7 +371,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
ModuleAnalysisManager &MAM) {
const DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
const dxil::Resources &MDResources = MAM.getResult<DXILResourceMDAnalysis>(M);
- const ComputedShaderFlags &ShaderFlags =
+ const DXILModuleShaderFlagsInfo &ShaderFlags =
MAM.getResult<ShaderFlagsAnalysis>(M);
const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
@@ -393,7 +403,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
getAnalysis<DXILResourceWrapperPass>().getResourceMap();
const dxil::Resources &MDResources =
getAnalysis<DXILResourceMDWrapper>().getDXILResource();
- const ComputedShaderFlags &ShaderFlags =
+ const DXILModuleShaderFlagsInfo &ShaderFlags =
getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
dxil::ModuleMetadataInfo MMDI =
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
``````````
</details>
https://github.com/llvm/llvm-project/pull/112967
More information about the llvm-commits
mailing list