[llvm] [DirectX] Get resource information via TargetExtType (PR #119772)
Justin Bogner via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 16 10:32:45 PST 2024
https://github.com/bogner updated https://github.com/llvm/llvm-project/pull/119772
>From 9f5564c72ada73feb88bcb82fcdfcc7ca4c4cada Mon Sep 17 00:00:00 2001
From: Justin Bogner <mail at justinbogner.com>
Date: Thu, 12 Dec 2024 13:26:07 -0800
Subject: [PATCH 1/3] [DirectX] Get resource information via TargetExtType
Instead of storing an auxiliary structure with the information from the
DXIL resource target extension types duplicated, access the information
that we can via the type itself.
This also means we need to handle some of the target extension types we
haven't fully defined yet, like Texture and CBuffer. For now we make an
educated guess as to what those should look like based on llvm/wg-hlsl#76,
and we can update them fairly easily when we've defined them more
thoroughly.
First part of #118400
---
llvm/include/llvm/Analysis/DXILResource.h | 384 ++++++----
llvm/lib/Analysis/DXILResource.cpp | 719 ++++++++----------
llvm/lib/Target/DirectX/DXILOpLowering.cpp | 2 +-
llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp | 2 +-
.../Target/DirectX/DXILTranslateMetadata.cpp | 8 +-
llvm/unittests/Analysis/DXILResourceTest.cpp | 395 ++++++----
6 files changed, 795 insertions(+), 715 deletions(-)
diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 6b577c02f05450..0205356af54443 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -11,6 +11,7 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/Alignment.h"
@@ -18,12 +19,183 @@
namespace llvm {
class CallInst;
+class DataLayout;
class LLVMContext;
class MDTuple;
+class TargetExtType;
class Value;
namespace dxil {
+/// The dx.RawBuffer target extension type
+///
+/// `target("dx.RawBuffer", Type, IsWriteable, IsROV)`
+class RawBufferExtType : public TargetExtType {
+public:
+ RawBufferExtType() = delete;
+ RawBufferExtType(const RawBufferExtType &) = delete;
+ RawBufferExtType &operator=(const RawBufferExtType &) = delete;
+
+ bool isStructured() const {
+ // TODO: We need to be more prescriptive here, but since there's some debate
+ // over whether byte address buffer should have a void type or an i8 type,
+ // accept either for now.
+ Type *Ty = getTypeParameter(0);
+ return !Ty->isVoidTy() && !Ty->isIntegerTy(8);
+ }
+
+ Type *getResourceType() const {
+ return isStructured() ? getTypeParameter(0) : nullptr;
+ }
+ bool isWriteable() const { return getIntParameter(0); }
+ bool isROV() const { return getIntParameter(1); }
+
+ static bool classof(const TargetExtType *T) {
+ return T->getName() == "dx.RawBuffer";
+ }
+ static bool classof(const Type *T) {
+ return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
+ }
+};
+
+/// The dx.TypedBuffer target extension type
+///
+/// `target("dx.TypedBuffer", Type, IsWriteable, IsROV, IsSigned)`
+class TypedBufferExtType : public TargetExtType {
+public:
+ TypedBufferExtType() = delete;
+ TypedBufferExtType(const TypedBufferExtType &) = delete;
+ TypedBufferExtType &operator=(const TypedBufferExtType &) = delete;
+
+ Type *getResourceType() const { return getTypeParameter(0); }
+ bool isWriteable() const { return getIntParameter(0); }
+ bool isROV() const { return getIntParameter(1); }
+ bool isSigned() const { return getIntParameter(2); }
+
+ static bool classof(const TargetExtType *T) {
+ return T->getName() == "dx.TypedBuffer";
+ }
+ static bool classof(const Type *T) {
+ return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
+ }
+};
+
+/// The dx.Texture target extension type
+///
+/// `target("dx.Texture", Type, IsWriteable, IsROV, IsSigned, Dimension)`
+class TextureExtType : public TargetExtType {
+public:
+ TextureExtType() = delete;
+ TextureExtType(const TextureExtType &) = delete;
+ TextureExtType &operator=(const TextureExtType &) = delete;
+
+ Type *getResourceType() const { return getTypeParameter(0); }
+ bool isWriteable() const { return getIntParameter(0); }
+ bool isROV() const { return getIntParameter(1); }
+ bool isSigned() const { return getIntParameter(2); }
+ dxil::ResourceKind getDimension() const {
+ return static_cast<dxil::ResourceKind>(getIntParameter(3));
+ }
+
+ static bool classof(const TargetExtType *T) {
+ return T->getName() == "dx.Texture";
+ }
+ static bool classof(const Type *T) {
+ return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
+ }
+};
+
+/// The dx.MSTexture target extension type
+///
+/// `target("dx.MSTexture", Type, IsWriteable, Samples, IsSigned, Dimension)`
+class MSTextureExtType : public TargetExtType {
+public:
+ MSTextureExtType() = delete;
+ MSTextureExtType(const MSTextureExtType &) = delete;
+ MSTextureExtType &operator=(const MSTextureExtType &) = delete;
+
+ Type *getResourceType() const { return getTypeParameter(0); }
+ bool isWriteable() const { return getIntParameter(0); }
+ uint32_t getSampleCount() const { return getIntParameter(1); }
+ bool isSigned() const { return getIntParameter(2); }
+ dxil::ResourceKind getDimension() const {
+ return static_cast<dxil::ResourceKind>(getIntParameter(3));
+ }
+
+ static bool classof(const TargetExtType *T) {
+ return T->getName() == "dx.MSTexture";
+ }
+ static bool classof(const Type *T) {
+ return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
+ }
+};
+
+/// The dx.FeedbackTexture target extension type
+///
+/// `target("dx.FeedbackTexture", FeedbackType, Dimension)`
+class FeedbackTextureExtType : public TargetExtType {
+public:
+ FeedbackTextureExtType() = delete;
+ FeedbackTextureExtType(const FeedbackTextureExtType &) = delete;
+ FeedbackTextureExtType &operator=(const FeedbackTextureExtType &) = delete;
+
+ dxil::SamplerFeedbackType getFeedbackType() const {
+ return static_cast<dxil::SamplerFeedbackType>(getIntParameter(0));
+ }
+ dxil::ResourceKind getDimension() const {
+ return static_cast<dxil::ResourceKind>(getIntParameter(1));
+ }
+
+ static bool classof(const TargetExtType *T) {
+ return T->getName() == "dx.FeedbackTexture";
+ }
+ static bool classof(const Type *T) {
+ return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
+ }
+};
+
+/// The dx.CBuffer target extension type
+///
+/// `target("dx.CBuffer", <Type>, ...)`
+class CBufferExtType : public TargetExtType {
+public:
+ CBufferExtType() = delete;
+ CBufferExtType(const CBufferExtType &) = delete;
+ CBufferExtType &operator=(const CBufferExtType &) = delete;
+
+ Type *getResourceType() const { return getTypeParameter(0); }
+
+ static bool classof(const TargetExtType *T) {
+ return T->getName() == "dx.CBuffer";
+ }
+ static bool classof(const Type *T) {
+ return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
+ }
+};
+
+/// The dx.Sampler target extension type
+///
+/// `target("dx.Sampler", SamplerType)`
+class SamplerExtType : public TargetExtType {
+public:
+ SamplerExtType() = delete;
+ SamplerExtType(const SamplerExtType &) = delete;
+ SamplerExtType &operator=(const SamplerExtType &) = delete;
+
+ dxil::SamplerType getSamplerType() const {
+ return static_cast<dxil::SamplerType>(getIntParameter(0));
+ }
+
+ static bool classof(const TargetExtType *T) {
+ return T->getName() == "dx.Sampler";
+ }
+ static bool classof(const Type *T) {
+ return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
+ }
+};
+
+//===----------------------------------------------------------------------===//
+
class ResourceInfo {
public:
struct ResourceBinding {
@@ -93,55 +265,27 @@ class ResourceInfo {
}
};
- struct MSInfo {
- uint32_t Count;
-
- bool operator==(const MSInfo &RHS) const { return Count == RHS.Count; }
- bool operator!=(const MSInfo &RHS) const { return !(*this == RHS); }
- bool operator<(const MSInfo &RHS) const { return Count < RHS.Count; }
- };
-
- struct FeedbackInfo {
- dxil::SamplerFeedbackType Type;
-
- bool operator==(const FeedbackInfo &RHS) const { return Type == RHS.Type; }
- bool operator!=(const FeedbackInfo &RHS) const { return !(*this == RHS); }
- bool operator<(const FeedbackInfo &RHS) const { return Type < RHS.Type; }
- };
-
private:
- // Universal properties.
- Value *Symbol;
- StringRef Name;
+ ResourceBinding Binding;
+ TargetExtType *HandleTy;
+
+ // GloballyCoherent and HasCounter aren't really part of the type and need to
+ // be determined by analysis, so they're just provided directly when we
+ // construct these.
+ bool GloballyCoherent;
+ bool HasCounter;
dxil::ResourceClass RC;
dxil::ResourceKind Kind;
- ResourceBinding Binding = {};
-
- // Resource class dependent properties.
- // CBuffer, Sampler, and RawBuffer end here.
- union {
- UAVInfo UAVFlags; // UAV
- uint32_t CBufferSize; // CBuffer
- dxil::SamplerType SamplerTy; // Sampler
- };
-
- // Resource kind dependent properties.
- union {
- StructInfo Struct; // StructuredBuffer
- TypedInfo Typed; // All SRV/UAV except Raw/StructuredBuffer
- FeedbackInfo Feedback; // FeedbackTexture
- };
-
- MSInfo MultiSample;
-
public:
- ResourceInfo(dxil::ResourceClass RC, dxil::ResourceKind Kind, Value *Symbol,
- StringRef Name)
- : Symbol(Symbol), Name(Name), RC(RC), Kind(Kind) {}
+ ResourceInfo(uint32_t RecordID, uint32_t Space, uint32_t LowerBound,
+ uint32_t Size, TargetExtType *HandleTy,
+ bool GloballyCoherent = false, bool HasCounter = false);
+
+ TargetExtType *getHandleTy() const { return HandleTy; }
- // Conditions to check before accessing union members.
+ // Conditions to check before accessing specific views.
bool isUAV() const;
bool isCBuffer() const;
bool isSampler() const;
@@ -150,148 +294,69 @@ class ResourceInfo {
bool isFeedback() const;
bool isMultiSample() const;
- void bind(uint32_t RecordID, uint32_t Space, uint32_t LowerBound,
- uint32_t Size) {
- Binding.RecordID = RecordID;
- Binding.Space = Space;
- Binding.LowerBound = LowerBound;
- Binding.Size = Size;
- }
- const ResourceBinding &getBinding() const { return Binding; }
- void setUAV(bool GloballyCoherent, bool HasCounter, bool IsROV) {
- assert(isUAV() && "Not a UAV");
- UAVFlags.GloballyCoherent = GloballyCoherent;
- UAVFlags.HasCounter = HasCounter;
- UAVFlags.IsROV = IsROV;
- }
- const UAVInfo &getUAV() const {
- assert(isUAV() && "Not a UAV");
- return UAVFlags;
- }
- void setCBuffer(uint32_t Size) {
- assert(isCBuffer() && "Not a CBuffer");
- CBufferSize = Size;
- }
- void setSampler(dxil::SamplerType Ty) { SamplerTy = Ty; }
- void setStruct(uint32_t Stride, MaybeAlign Alignment) {
- assert(isStruct() && "Not a Struct");
- Struct.Stride = Stride;
- Struct.AlignLog2 = Alignment ? Log2(*Alignment) : 0;
- }
- void setTyped(dxil::ElementType ElementTy, uint32_t ElementCount) {
- assert(isTyped() && "Not Typed");
- Typed.ElementTy = ElementTy;
- Typed.ElementCount = ElementCount;
- }
- const TypedInfo &getTyped() const {
- assert(isTyped() && "Not typed");
- return Typed;
- }
- void setFeedback(dxil::SamplerFeedbackType Type) {
- assert(isFeedback() && "Not Feedback");
- Feedback.Type = Type;
- }
- void setMultiSample(uint32_t Count) {
- assert(isMultiSample() && "Not MultiSampled");
- MultiSample.Count = Count;
+ // Views into the type.
+ UAVInfo getUAV() const;
+ uint32_t getCBufferSize(const DataLayout &DL) const;
+ dxil::SamplerType getSamplerType() const;
+ StructInfo getStruct(const DataLayout &DL) const;
+ TypedInfo getTyped() const;
+ dxil::SamplerFeedbackType getFeedbackType() const;
+ uint32_t getMultiSampleCount() const;
+
+ StringRef getName() const {
+ // TODO: Get the name from the symbol once we include one here.
+ return "";
}
- const MSInfo &getMultiSample() const {
- assert(isMultiSample() && "Not MultiSampled");
- return MultiSample;
- }
-
- StringRef getName() const { return Name; }
dxil::ResourceClass getResourceClass() const { return RC; }
dxil::ResourceKind getResourceKind() const { return Kind; }
+ void setBindingID(unsigned ID) { Binding.RecordID = ID; }
+
+ const ResourceBinding &getBinding() const { return Binding; }
+
+ MDTuple *getAsMetadata(Module &M) const;
+ std::pair<uint32_t, uint32_t> getAnnotateProps(Module &M) const;
+
bool operator==(const ResourceInfo &RHS) const;
bool operator!=(const ResourceInfo &RHS) const { return !(*this == RHS); }
bool operator<(const ResourceInfo &RHS) const;
- static ResourceInfo SRV(Value *Symbol, StringRef Name,
- dxil::ElementType ElementTy, uint32_t ElementCount,
- dxil::ResourceKind Kind);
- static ResourceInfo RawBuffer(Value *Symbol, StringRef Name);
- static ResourceInfo StructuredBuffer(Value *Symbol, StringRef Name,
- uint32_t Stride, MaybeAlign Alignment);
- static ResourceInfo Texture2DMS(Value *Symbol, StringRef Name,
- dxil::ElementType ElementTy,
- uint32_t ElementCount, uint32_t SampleCount);
- static ResourceInfo Texture2DMSArray(Value *Symbol, StringRef Name,
- dxil::ElementType ElementTy,
- uint32_t ElementCount,
- uint32_t SampleCount);
-
- static ResourceInfo UAV(Value *Symbol, StringRef Name,
- dxil::ElementType ElementTy, uint32_t ElementCount,
- bool GloballyCoherent, bool IsROV,
- dxil::ResourceKind Kind);
- static ResourceInfo RWRawBuffer(Value *Symbol, StringRef Name,
- bool GloballyCoherent, bool IsROV);
- static ResourceInfo RWStructuredBuffer(Value *Symbol, StringRef Name,
- uint32_t Stride, MaybeAlign Alignment,
- bool GloballyCoherent, bool IsROV,
- bool HasCounter);
- static ResourceInfo RWTexture2DMS(Value *Symbol, StringRef Name,
- dxil::ElementType ElementTy,
- uint32_t ElementCount, uint32_t SampleCount,
- bool GloballyCoherent);
- static ResourceInfo RWTexture2DMSArray(Value *Symbol, StringRef Name,
- dxil::ElementType ElementTy,
- uint32_t ElementCount,
- uint32_t SampleCount,
- bool GloballyCoherent);
- static ResourceInfo FeedbackTexture2D(Value *Symbol, StringRef Name,
- dxil::SamplerFeedbackType FeedbackTy);
- static ResourceInfo
- FeedbackTexture2DArray(Value *Symbol, StringRef Name,
- dxil::SamplerFeedbackType FeedbackTy);
-
- static ResourceInfo CBuffer(Value *Symbol, StringRef Name, uint32_t Size);
-
- static ResourceInfo Sampler(Value *Symbol, StringRef Name,
- dxil::SamplerType SamplerTy);
-
- MDTuple *getAsMetadata(LLVMContext &Ctx) const;
-
- std::pair<uint32_t, uint32_t> getAnnotateProps() const;
-
- void print(raw_ostream &OS) const;
+ void print(raw_ostream &OS, const DataLayout &DL) const;
};
} // namespace dxil
+//===----------------------------------------------------------------------===//
+
class DXILResourceMap {
- SmallVector<dxil::ResourceInfo> Resources;
+ SmallVector<dxil::ResourceInfo> Infos;
DenseMap<CallInst *, unsigned> CallMap;
unsigned FirstUAV = 0;
unsigned FirstCBuffer = 0;
unsigned FirstSampler = 0;
+ /// Populate the map given the resource binding calls in the given module.
+ void populate(Module &M);
+
public:
using iterator = SmallVector<dxil::ResourceInfo>::iterator;
using const_iterator = SmallVector<dxil::ResourceInfo>::const_iterator;
- DXILResourceMap(
- SmallVectorImpl<std::pair<CallInst *, dxil::ResourceInfo>> &&CIToRI);
+ iterator begin() { return Infos.begin(); }
+ const_iterator begin() const { return Infos.begin(); }
+ iterator end() { return Infos.end(); }
+ const_iterator end() const { return Infos.end(); }
- iterator begin() { return Resources.begin(); }
- const_iterator begin() const { return Resources.begin(); }
- iterator end() { return Resources.end(); }
- const_iterator end() const { return Resources.end(); }
-
- bool empty() const { return Resources.empty(); }
+ bool empty() const { return Infos.empty(); }
iterator find(const CallInst *Key) {
auto Pos = CallMap.find(Key);
- return Pos == CallMap.end() ? Resources.end()
- : (Resources.begin() + Pos->second);
+ return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second);
}
const_iterator find(const CallInst *Key) const {
auto Pos = CallMap.find(Key);
- return Pos == CallMap.end() ? Resources.end()
- : (Resources.begin() + Pos->second);
+ return Pos == CallMap.end() ? Infos.end() : (Infos.begin() + Pos->second);
}
iterator srv_begin() { return begin(); }
@@ -334,7 +399,10 @@ class DXILResourceMap {
return make_range(sampler_begin(), sampler_end());
}
- void print(raw_ostream &OS) const;
+ void print(raw_ostream &OS, const DataLayout &DL) const;
+
+ friend class DXILResourceAnalysis;
+ friend class DXILResourceWrapperPass;
};
class DXILResourceAnalysis : public AnalysisInfoMixin<DXILResourceAnalysis> {
@@ -362,7 +430,7 @@ class DXILResourcePrinterPass : public PassInfoMixin<DXILResourcePrinterPass> {
};
class DXILResourceWrapperPass : public ModulePass {
- std::unique_ptr<DXILResourceMap> ResourceMap;
+ std::unique_ptr<DXILResourceMap> Map;
public:
static char ID; // Class identification, replacement for typeinfo
@@ -370,8 +438,8 @@ class DXILResourceWrapperPass : public ModulePass {
DXILResourceWrapperPass();
~DXILResourceWrapperPass() override;
- const DXILResourceMap &getResourceMap() const { return *ResourceMap; }
- DXILResourceMap &getResourceMap() { return *ResourceMap; }
+ const DXILResourceMap &getResourceMap() const { return *Map; }
+ DXILResourceMap &getResourceMap() { return *Map; }
void getAnalysisUsage(AnalysisUsage &AU) const override;
bool runOnModule(Module &M) override;
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index 2802480481690d..f96a9468d6bc54 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -8,6 +8,7 @@
#include "llvm/Analysis/DXILResource.h"
#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/SmallString.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/DiagnosticInfo.h"
@@ -17,6 +18,7 @@
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
+#include "llvm/Support/FormatVariadic.h"
#define DEBUG_TYPE "dxil-resource"
@@ -148,11 +150,74 @@ static StringRef getSamplerFeedbackTypeName(SamplerFeedbackType SFT) {
llvm_unreachable("Unhandled SamplerFeedbackType");
}
+static dxil::ElementType toDXILElementType(Type *Ty, bool IsSigned) {
+ // TODO: Handle unorm, snorm, and packed.
+ Ty = Ty->getScalarType();
+
+ if (Ty->isIntegerTy()) {
+ switch (Ty->getIntegerBitWidth()) {
+ case 16:
+ return IsSigned ? ElementType::I16 : ElementType::U16;
+ case 32:
+ return IsSigned ? ElementType::I32 : ElementType::U32;
+ case 64:
+ return IsSigned ? ElementType::I64 : ElementType::U64;
+ case 1:
+ default:
+ return ElementType::Invalid;
+ }
+ } else if (Ty->isFloatTy()) {
+ return ElementType::F32;
+ } else if (Ty->isDoubleTy()) {
+ return ElementType::F64;
+ } else if (Ty->isHalfTy()) {
+ return ElementType::F16;
+ }
+
+ return ElementType::Invalid;
+}
+
+ResourceInfo::ResourceInfo(uint32_t RecordID, uint32_t Space,
+ uint32_t LowerBound, uint32_t Size,
+ TargetExtType *HandleTy, bool GloballyCoherent,
+ bool HasCounter)
+ : Binding{RecordID, Space, LowerBound, Size}, HandleTy(HandleTy),
+ GloballyCoherent(GloballyCoherent), HasCounter(HasCounter) {
+ if (auto *Ty = dyn_cast<RawBufferExtType>(HandleTy)) {
+ RC = Ty->isWriteable() ? ResourceClass::UAV : ResourceClass::SRV;
+ Kind = Ty->isStructured() ? ResourceKind::StructuredBuffer
+ : ResourceKind::RawBuffer;
+ } else if (auto *Ty = dyn_cast<TypedBufferExtType>(HandleTy)) {
+ RC = Ty->isWriteable() ? ResourceClass::UAV : ResourceClass::SRV;
+ Kind = ResourceKind::TypedBuffer;
+ } else if (auto *Ty = dyn_cast<TextureExtType>(HandleTy)) {
+ RC = Ty->isWriteable() ? ResourceClass::UAV : ResourceClass::SRV;
+ Kind = Ty->getDimension();
+ } else if (auto *Ty = dyn_cast<MSTextureExtType>(HandleTy)) {
+ RC = Ty->isWriteable() ? ResourceClass::UAV : ResourceClass::SRV;
+ Kind = Ty->getDimension();
+ } else if (auto *Ty = dyn_cast<FeedbackTextureExtType>(HandleTy)) {
+ RC = ResourceClass::UAV;
+ Kind = Ty->getDimension();
+ } else if (isa<CBufferExtType>(HandleTy)) {
+ RC = ResourceClass::CBuffer;
+ Kind = ResourceKind::CBuffer;
+ } else if (isa<SamplerExtType>(HandleTy)) {
+ RC = ResourceClass::Sampler;
+ Kind = ResourceKind::Sampler;
+ } else
+ llvm_unreachable("Unknown handle type");
+}
+
bool ResourceInfo::isUAV() const { return RC == ResourceClass::UAV; }
-bool ResourceInfo::isCBuffer() const { return RC == ResourceClass::CBuffer; }
+bool ResourceInfo::isCBuffer() const {
+ return RC == ResourceClass::CBuffer;
+}
-bool ResourceInfo::isSampler() const { return RC == ResourceClass::Sampler; }
+bool ResourceInfo::isSampler() const {
+ return RC == ResourceClass::Sampler;
+}
bool ResourceInfo::isStruct() const {
return Kind == ResourceKind::StructuredBuffer;
@@ -197,184 +262,129 @@ bool ResourceInfo::isMultiSample() const {
Kind == ResourceKind::Texture2DMSArray;
}
-ResourceInfo ResourceInfo::SRV(Value *Symbol, StringRef Name,
- ElementType ElementTy, uint32_t ElementCount,
- ResourceKind Kind) {
- ResourceInfo RI(ResourceClass::SRV, Kind, Symbol, Name);
- assert(RI.isTyped() && !(RI.isStruct() || RI.isMultiSample()) &&
- "Invalid ResourceKind for SRV constructor.");
- RI.setTyped(ElementTy, ElementCount);
- return RI;
-}
-
-ResourceInfo ResourceInfo::RawBuffer(Value *Symbol, StringRef Name) {
- ResourceInfo RI(ResourceClass::SRV, ResourceKind::RawBuffer, Symbol, Name);
- return RI;
-}
-
-ResourceInfo ResourceInfo::StructuredBuffer(Value *Symbol, StringRef Name,
- uint32_t Stride,
- MaybeAlign Alignment) {
- ResourceInfo RI(ResourceClass::SRV, ResourceKind::StructuredBuffer, Symbol,
- Name);
- RI.setStruct(Stride, Alignment);
- return RI;
-}
-
-ResourceInfo ResourceInfo::Texture2DMS(Value *Symbol, StringRef Name,
- ElementType ElementTy,
- uint32_t ElementCount,
- uint32_t SampleCount) {
- ResourceInfo RI(ResourceClass::SRV, ResourceKind::Texture2DMS, Symbol, Name);
- RI.setTyped(ElementTy, ElementCount);
- RI.setMultiSample(SampleCount);
- return RI;
-}
-
-ResourceInfo ResourceInfo::Texture2DMSArray(Value *Symbol, StringRef Name,
- ElementType ElementTy,
- uint32_t ElementCount,
- uint32_t SampleCount) {
- ResourceInfo RI(ResourceClass::SRV, ResourceKind::Texture2DMSArray, Symbol,
- Name);
- RI.setTyped(ElementTy, ElementCount);
- RI.setMultiSample(SampleCount);
- return RI;
-}
-
-ResourceInfo ResourceInfo::UAV(Value *Symbol, StringRef Name,
- ElementType ElementTy, uint32_t ElementCount,
- bool GloballyCoherent, bool IsROV,
- ResourceKind Kind) {
- ResourceInfo RI(ResourceClass::UAV, Kind, Symbol, Name);
- assert(RI.isTyped() && !(RI.isStruct() || RI.isMultiSample()) &&
- "Invalid ResourceKind for UAV constructor.");
- RI.setTyped(ElementTy, ElementCount);
- RI.setUAV(GloballyCoherent, /*HasCounter=*/false, IsROV);
- return RI;
-}
-
-ResourceInfo ResourceInfo::RWRawBuffer(Value *Symbol, StringRef Name,
- bool GloballyCoherent, bool IsROV) {
- ResourceInfo RI(ResourceClass::UAV, ResourceKind::RawBuffer, Symbol, Name);
- RI.setUAV(GloballyCoherent, /*HasCounter=*/false, IsROV);
- return RI;
-}
-
-ResourceInfo ResourceInfo::RWStructuredBuffer(Value *Symbol, StringRef Name,
- uint32_t Stride,
- MaybeAlign Alignment,
- bool GloballyCoherent, bool IsROV,
- bool HasCounter) {
- ResourceInfo RI(ResourceClass::UAV, ResourceKind::StructuredBuffer, Symbol,
- Name);
- RI.setStruct(Stride, Alignment);
- RI.setUAV(GloballyCoherent, HasCounter, IsROV);
- return RI;
-}
-
-ResourceInfo ResourceInfo::RWTexture2DMS(Value *Symbol, StringRef Name,
- ElementType ElementTy,
- uint32_t ElementCount,
- uint32_t SampleCount,
- bool GloballyCoherent) {
- ResourceInfo RI(ResourceClass::UAV, ResourceKind::Texture2DMS, Symbol, Name);
- RI.setTyped(ElementTy, ElementCount);
- RI.setUAV(GloballyCoherent, /*HasCounter=*/false, /*IsROV=*/false);
- RI.setMultiSample(SampleCount);
- return RI;
-}
-
-ResourceInfo ResourceInfo::RWTexture2DMSArray(Value *Symbol, StringRef Name,
- ElementType ElementTy,
- uint32_t ElementCount,
- uint32_t SampleCount,
- bool GloballyCoherent) {
- ResourceInfo RI(ResourceClass::UAV, ResourceKind::Texture2DMSArray, Symbol,
- Name);
- RI.setTyped(ElementTy, ElementCount);
- RI.setUAV(GloballyCoherent, /*HasCounter=*/false, /*IsROV=*/false);
- RI.setMultiSample(SampleCount);
- return RI;
-}
-
-ResourceInfo ResourceInfo::FeedbackTexture2D(Value *Symbol, StringRef Name,
- SamplerFeedbackType FeedbackTy) {
- ResourceInfo RI(ResourceClass::UAV, ResourceKind::FeedbackTexture2D, Symbol,
- Name);
- RI.setUAV(/*GloballyCoherent=*/false, /*HasCounter=*/false, /*IsROV=*/false);
- RI.setFeedback(FeedbackTy);
- return RI;
-}
-
-ResourceInfo
-ResourceInfo::FeedbackTexture2DArray(Value *Symbol, StringRef Name,
- SamplerFeedbackType FeedbackTy) {
- ResourceInfo RI(ResourceClass::UAV, ResourceKind::FeedbackTexture2DArray,
- Symbol, Name);
- RI.setUAV(/*GloballyCoherent=*/false, /*HasCounter=*/false, /*IsROV=*/false);
- RI.setFeedback(FeedbackTy);
- return RI;
-}
-
-ResourceInfo ResourceInfo::CBuffer(Value *Symbol, StringRef Name,
- uint32_t Size) {
- ResourceInfo RI(ResourceClass::CBuffer, ResourceKind::CBuffer, Symbol, Name);
- RI.setCBuffer(Size);
- return RI;
-}
-
-ResourceInfo ResourceInfo::Sampler(Value *Symbol, StringRef Name,
- SamplerType SamplerTy) {
- ResourceInfo RI(ResourceClass::Sampler, ResourceKind::Sampler, Symbol, Name);
- RI.setSampler(SamplerTy);
- return RI;
+static bool isROV(dxil::ResourceKind Kind, TargetExtType *Ty) {
+ switch (Kind) {
+ case ResourceKind::Texture1D:
+ case ResourceKind::Texture2D:
+ case ResourceKind::Texture3D:
+ case ResourceKind::TextureCube:
+ case ResourceKind::Texture1DArray:
+ case ResourceKind::Texture2DArray:
+ case ResourceKind::TextureCubeArray:
+ return cast<TextureExtType>(Ty)->isROV();
+ case ResourceKind::TypedBuffer:
+ return cast<TypedBufferExtType>(Ty)->isROV();
+ case ResourceKind::RawBuffer:
+ case ResourceKind::StructuredBuffer:
+ return cast<RawBufferExtType>(Ty)->isROV();
+ case ResourceKind::Texture2DMS:
+ case ResourceKind::Texture2DMSArray:
+ case ResourceKind::FeedbackTexture2D:
+ case ResourceKind::FeedbackTexture2DArray:
+ return false;
+ case ResourceKind::CBuffer:
+ case ResourceKind::Sampler:
+ case ResourceKind::TBuffer:
+ case ResourceKind::RTAccelerationStructure:
+ case ResourceKind::Invalid:
+ case ResourceKind::NumEntries:
+ llvm_unreachable("Resource cannot be ROV");
+ }
+ llvm_unreachable("Unhandled ResourceKind enum");
+}
+
+ResourceInfo::UAVInfo ResourceInfo::getUAV() const {
+ assert(isUAV() && "Not a UAV");
+ return {GloballyCoherent, HasCounter, isROV(Kind, HandleTy)};
}
-bool ResourceInfo::operator==(const ResourceInfo &RHS) const {
- if (std::tie(Symbol, Name, Binding, RC, Kind) !=
- std::tie(RHS.Symbol, RHS.Name, RHS.Binding, RHS.RC, RHS.Kind))
- return false;
- if (isCBuffer() && RHS.isCBuffer() && CBufferSize != RHS.CBufferSize)
- return false;
- if (isSampler() && RHS.isSampler() && SamplerTy != RHS.SamplerTy)
- return false;
- if (isUAV() && RHS.isUAV() && UAVFlags != RHS.UAVFlags)
- return false;
- if (isStruct() && RHS.isStruct() && Struct != RHS.Struct)
- return false;
- if (isFeedback() && RHS.isFeedback() && Feedback != RHS.Feedback)
- return false;
- if (isTyped() && RHS.isTyped() && Typed != RHS.Typed)
- return false;
- if (isMultiSample() && RHS.isMultiSample() && MultiSample != RHS.MultiSample)
- return false;
- return true;
+uint32_t ResourceInfo::getCBufferSize(const DataLayout &DL) const {
+ assert(isCBuffer() && "Not a CBuffer");
+ Type *Ty = cast<CBufferExtType>(HandleTy)->getResourceType();
+ return DL.getTypeSizeInBits(Ty) / 8;
}
-bool ResourceInfo::operator<(const ResourceInfo &RHS) const {
- // Skip the symbol to avoid non-determinism, and the name to keep a consistent
- // ordering even when we strip reflection data.
- if (std::tie(Binding, RC, Kind) < std::tie(RHS.Binding, RHS.RC, RHS.Kind))
- return true;
- if (isCBuffer() && RHS.isCBuffer() && CBufferSize < RHS.CBufferSize)
- return true;
- if (isSampler() && RHS.isSampler() && SamplerTy < RHS.SamplerTy)
- return true;
- if (isUAV() && RHS.isUAV() && UAVFlags < RHS.UAVFlags)
- return true;
- if (isStruct() && RHS.isStruct() && Struct < RHS.Struct)
- return true;
- if (isFeedback() && RHS.isFeedback() && Feedback < RHS.Feedback)
- return true;
- if (isTyped() && RHS.isTyped() && Typed < RHS.Typed)
- return true;
- if (isMultiSample() && RHS.isMultiSample() && MultiSample < RHS.MultiSample)
- return true;
- return false;
+dxil::SamplerType ResourceInfo::getSamplerType() const {
+ assert(isSampler() && "Not a Sampler");
+ return cast<SamplerExtType>(HandleTy)->getSamplerType();
+}
+
+ResourceInfo::StructInfo
+ResourceInfo::getStruct(const DataLayout &DL) const {
+ assert(isStruct() && "Not a Struct");
+
+ Type *ElTy = cast<RawBufferExtType>(HandleTy)->getResourceType();
+
+ uint32_t Stride = DL.getTypeAllocSize(ElTy);
+ MaybeAlign Alignment;
+ if (auto *STy = dyn_cast<StructType>(ElTy))
+ Alignment = DL.getStructLayout(STy)->getAlignment();
+ uint32_t AlignLog2 = Alignment ? Log2(*Alignment) : 0;
+ return {Stride, AlignLog2};
}
-MDTuple *ResourceInfo::getAsMetadata(LLVMContext &Ctx) const {
+static std::pair<Type *, bool> getTypedElementType(dxil::ResourceKind Kind,
+ TargetExtType *Ty) {
+ switch (Kind) {
+ case ResourceKind::Texture1D:
+ case ResourceKind::Texture2D:
+ case ResourceKind::Texture3D:
+ case ResourceKind::TextureCube:
+ case ResourceKind::Texture1DArray:
+ case ResourceKind::Texture2DArray:
+ case ResourceKind::TextureCubeArray: {
+ auto *RTy = cast<TextureExtType>(Ty);
+ return {RTy->getResourceType(), RTy->isSigned()};
+ }
+ case ResourceKind::Texture2DMS:
+ case ResourceKind::Texture2DMSArray: {
+ auto *RTy = cast<MSTextureExtType>(Ty);
+ return {RTy->getResourceType(), RTy->isSigned()};
+ }
+ case ResourceKind::TypedBuffer: {
+ auto *RTy = cast<TypedBufferExtType>(Ty);
+ return {RTy->getResourceType(), RTy->isSigned()};
+ }
+ case ResourceKind::RawBuffer:
+ case ResourceKind::StructuredBuffer:
+ case ResourceKind::FeedbackTexture2D:
+ case ResourceKind::FeedbackTexture2DArray:
+ case ResourceKind::CBuffer:
+ case ResourceKind::Sampler:
+ case ResourceKind::TBuffer:
+ case ResourceKind::RTAccelerationStructure:
+ case ResourceKind::Invalid:
+ case ResourceKind::NumEntries:
+ llvm_unreachable("Resource is not typed");
+ }
+ llvm_unreachable("Unhandled ResourceKind enum");
+}
+
+ResourceInfo::TypedInfo ResourceInfo::getTyped() const {
+ assert(isTyped() && "Not typed");
+
+ auto [ElTy, IsSigned] = getTypedElementType(Kind, HandleTy);
+ dxil::ElementType ET = toDXILElementType(ElTy, IsSigned);
+ uint32_t Count = 1;
+ if (auto *VTy = dyn_cast<FixedVectorType>(ElTy))
+ Count = VTy->getNumElements();
+ return {ET, Count};
+}
+
+dxil::SamplerFeedbackType ResourceInfo::getFeedbackType() const {
+ assert(isFeedback() && "Not Feedback");
+ return cast<FeedbackTextureExtType>(HandleTy)->getFeedbackType();
+}
+
+uint32_t ResourceInfo::getMultiSampleCount() const {
+ assert(isMultiSample() && "Not MultiSampled");
+ return cast<MSTextureExtType>(HandleTy)->getSampleCount();
+}
+
+MDTuple *ResourceInfo::getAsMetadata(Module &M) const {
+ LLVMContext &Ctx = M.getContext();
+ const DataLayout &DL = M.getDataLayout();
+
SmallVector<Metadata *, 11> MDVals;
Type *I32Ty = Type::getInt32Ty(Ctx);
@@ -389,22 +399,28 @@ MDTuple *ResourceInfo::getAsMetadata(LLVMContext &Ctx) const {
};
MDVals.push_back(getIntMD(Binding.RecordID));
- MDVals.push_back(ValueAsMetadata::get(Symbol));
- MDVals.push_back(MDString::get(Ctx, Name));
+
+ // TODO: We need API to create a symbol of the appropriate type to emit here.
+ // See https://github.com/llvm/llvm-project/issues/116849
+ MDVals.push_back(
+ ValueAsMetadata::get(UndefValue::get(PointerType::getUnqual(Ctx))));
+ MDVals.push_back(MDString::get(Ctx, ""));
+
MDVals.push_back(getIntMD(Binding.Space));
MDVals.push_back(getIntMD(Binding.LowerBound));
MDVals.push_back(getIntMD(Binding.Size));
if (isCBuffer()) {
- MDVals.push_back(getIntMD(CBufferSize));
+ MDVals.push_back(getIntMD(getCBufferSize(DL)));
MDVals.push_back(nullptr);
} else if (isSampler()) {
- MDVals.push_back(getIntMD(llvm::to_underlying(SamplerTy)));
+ MDVals.push_back(getIntMD(llvm::to_underlying(getSamplerType())));
MDVals.push_back(nullptr);
} else {
- MDVals.push_back(getIntMD(llvm::to_underlying(Kind)));
+ MDVals.push_back(getIntMD(llvm::to_underlying(getResourceKind())));
if (isUAV()) {
+ ResourceInfo::UAVInfo UAVFlags = getUAV();
MDVals.push_back(getBoolMD(UAVFlags.GloballyCoherent));
MDVals.push_back(getBoolMD(UAVFlags.HasCounter));
MDVals.push_back(getBoolMD(UAVFlags.IsROV));
@@ -412,7 +428,8 @@ MDTuple *ResourceInfo::getAsMetadata(LLVMContext &Ctx) const {
// All SRVs include sample count in the metadata, but it's only meaningful
// for multi-sampled textured. Also, UAVs can be multisampled in SM6.7+,
// but this just isn't reflected in the metadata at all.
- uint32_t SampleCount = isMultiSample() ? MultiSample.Count : 0;
+ uint32_t SampleCount =
+ isMultiSample() ? getMultiSampleCount() : 0;
MDVals.push_back(getIntMD(SampleCount));
}
@@ -421,14 +438,14 @@ MDTuple *ResourceInfo::getAsMetadata(LLVMContext &Ctx) const {
if (isStruct()) {
Tags.push_back(
getIntMD(llvm::to_underlying(ExtPropTags::StructuredBufferStride)));
- Tags.push_back(getIntMD(Struct.Stride));
+ Tags.push_back(getIntMD(getStruct(DL).Stride));
} else if (isTyped()) {
Tags.push_back(getIntMD(llvm::to_underlying(ExtPropTags::ElementType)));
- Tags.push_back(getIntMD(llvm::to_underlying(Typed.ElementTy)));
+ Tags.push_back(getIntMD(llvm::to_underlying(getTyped().ElementTy)));
} else if (isFeedback()) {
Tags.push_back(
getIntMD(llvm::to_underlying(ExtPropTags::SamplerFeedbackKind)));
- Tags.push_back(getIntMD(llvm::to_underlying(Feedback.Type)));
+ Tags.push_back(getIntMD(llvm::to_underlying(getFeedbackType())));
}
MDVals.push_back(Tags.empty() ? nullptr : MDNode::get(Ctx, Tags));
}
@@ -436,17 +453,21 @@ MDTuple *ResourceInfo::getAsMetadata(LLVMContext &Ctx) const {
return MDNode::get(Ctx, MDVals);
}
-std::pair<uint32_t, uint32_t> ResourceInfo::getAnnotateProps() const {
- uint32_t ResourceKind = llvm::to_underlying(Kind);
- uint32_t AlignLog2 = isStruct() ? Struct.AlignLog2 : 0;
+std::pair<uint32_t, uint32_t> ResourceInfo::getAnnotateProps(Module &M) const {
+ const DataLayout &DL = M.getDataLayout();
+
+ uint32_t ResourceKind = llvm::to_underlying(getResourceKind());
+ uint32_t AlignLog2 = isStruct() ? getStruct(DL).AlignLog2 : 0;
bool IsUAV = isUAV();
+ ResourceInfo::UAVInfo UAVFlags =
+ IsUAV ? getUAV() : ResourceInfo::UAVInfo{};
bool IsROV = IsUAV && UAVFlags.IsROV;
bool IsGloballyCoherent = IsUAV && UAVFlags.GloballyCoherent;
uint8_t SamplerCmpOrHasCounter = 0;
if (IsUAV)
SamplerCmpOrHasCounter = UAVFlags.HasCounter;
else if (isSampler())
- SamplerCmpOrHasCounter = SamplerTy == SamplerType::Comparison;
+ SamplerCmpOrHasCounter = getSamplerType() == SamplerType::Comparison;
// TODO: Document this format. Currently the only reference is the
// implementation of dxc's DxilResourceProperties struct.
@@ -460,15 +481,16 @@ std::pair<uint32_t, uint32_t> ResourceInfo::getAnnotateProps() const {
uint32_t Word1 = 0;
if (isStruct())
- Word1 = Struct.Stride;
+ Word1 = getStruct(DL).Stride;
else if (isCBuffer())
- Word1 = CBufferSize;
+ Word1 = getCBufferSize(DL);
else if (isFeedback())
- Word1 = llvm::to_underlying(Feedback.Type);
+ Word1 = llvm::to_underlying(getFeedbackType());
else if (isTyped()) {
+ ResourceInfo::TypedInfo Typed = getTyped();
uint32_t CompType = llvm::to_underlying(Typed.ElementTy);
uint32_t CompCount = Typed.ElementCount;
- uint32_t SampleCount = isMultiSample() ? MultiSample.Count : 0;
+ uint32_t SampleCount = isMultiSample() ? getMultiSampleCount() : 0;
Word1 |= (CompType & 0xFF) << 0;
Word1 |= (CompCount & 0xFF) << 8;
@@ -478,255 +500,131 @@ std::pair<uint32_t, uint32_t> ResourceInfo::getAnnotateProps() const {
return {Word0, Word1};
}
-void ResourceInfo::print(raw_ostream &OS) const {
- OS << " Symbol: ";
- Symbol->printAsOperand(OS);
- OS << "\n";
+/// Two infos are equal when their binding, handle type, and the
+/// GloballyCoherent / HasCounter flags all match.
+bool ResourceInfo::operator==(const ResourceInfo &RHS) const {
+  // Build the same tuple of compared fields for either operand.
+  auto Key = [](const ResourceInfo &RI) {
+    return std::tie(RI.Binding, RI.HandleTy, RI.GloballyCoherent,
+                    RI.HasCounter);
+  };
+  return Key(*this) == Key(RHS);
+}
+
+/// Strict weak ordering over ResourceInfo, used when sorting infos so that
+/// equal entries end up adjacent and can be deduplicated.
+///
+/// Keys are compared lexicographically: a later key is only consulted when
+/// every earlier key compares equivalent. Without that, a later key could
+/// report "less" even though an earlier key had already ordered RHS before
+/// *this, so both A < B and B < A could hold, which sorting algorithms do
+/// not tolerate.
+bool ResourceInfo::operator<(const ResourceInfo &RHS) const {
+  // An empty datalayout is sufficient for sorting purposes.
+  DataLayout DummyDL;
+
+  // Three-way comparison built only from operator<: negative if L orders
+  // first, positive if R orders first, zero if neither does.
+  auto Cmp = [](const auto &L, const auto &R) {
+    return L < R ? -1 : (R < L ? 1 : 0);
+  };
+
+  if (int C = Cmp(std::tie(Binding, RC, Kind),
+                  std::tie(RHS.Binding, RHS.RC, RHS.Kind)))
+    return C < 0;
+  if (isCBuffer() && RHS.isCBuffer())
+    if (int C = Cmp(getCBufferSize(DummyDL), RHS.getCBufferSize(DummyDL)))
+      return C < 0;
+  if (isSampler() && RHS.isSampler())
+    if (int C = Cmp(getSamplerType(), RHS.getSamplerType()))
+      return C < 0;
+  if (isUAV() && RHS.isUAV())
+    if (int C = Cmp(getUAV(), RHS.getUAV()))
+      return C < 0;
+  if (isStruct() && RHS.isStruct())
+    if (int C = Cmp(getStruct(DummyDL), RHS.getStruct(DummyDL)))
+      return C < 0;
+  if (isFeedback() && RHS.isFeedback())
+    if (int C = Cmp(getFeedbackType(), RHS.getFeedbackType()))
+      return C < 0;
+  if (isTyped() && RHS.isTyped())
+    if (int C = Cmp(getTyped(), RHS.getTyped()))
+      return C < 0;
+  if (isMultiSample() && RHS.isMultiSample())
+    if (int C = Cmp(getMultiSampleCount(), RHS.getMultiSampleCount()))
+      return C < 0;
+
+  // Final tie-break on the remaining fields that operator== compares, so two
+  // infos that are not == are never treated as equivalent by this ordering.
+  // NOTE(review): HandleTy is ordered by pointer value, which only matters
+  // when every key above ties.
+  return std::tie(HandleTy, GloballyCoherent, HasCounter) <
+         std::tie(RHS.HandleTy, RHS.GloballyCoherent, RHS.HasCounter);
+}
- OS << " Name: \"" << Name << "\"\n"
- << " Binding:\n"
+void ResourceInfo::print(raw_ostream &OS, const DataLayout &DL) const {
+ OS << " Binding:\n"
<< " Record ID: " << Binding.RecordID << "\n"
<< " Space: " << Binding.Space << "\n"
<< " Lower Bound: " << Binding.LowerBound << "\n"
- << " Size: " << Binding.Size << "\n"
- << " Class: " << getResourceClassName(RC) << "\n"
+ << " Size: " << Binding.Size << "\n";
+
+ OS << " Class: " << getResourceClassName(RC) << "\n"
<< " Kind: " << getResourceKindName(Kind) << "\n";
if (isCBuffer()) {
- OS << " CBuffer size: " << CBufferSize << "\n";
+ OS << " CBuffer size: " << getCBufferSize(DL) << "\n";
} else if (isSampler()) {
- OS << " Sampler Type: " << getSamplerTypeName(SamplerTy) << "\n";
+ OS << " Sampler Type: " << getSamplerTypeName(getSamplerType()) << "\n";
} else {
if (isUAV()) {
+ UAVInfo UAVFlags = getUAV();
OS << " Globally Coherent: " << UAVFlags.GloballyCoherent << "\n"
<< " HasCounter: " << UAVFlags.HasCounter << "\n"
<< " IsROV: " << UAVFlags.IsROV << "\n";
}
if (isMultiSample())
- OS << " Sample Count: " << MultiSample.Count << "\n";
+ OS << " Sample Count: " << getMultiSampleCount() << "\n";
if (isStruct()) {
+ StructInfo Struct = getStruct(DL);
OS << " Buffer Stride: " << Struct.Stride << "\n";
OS << " Alignment: " << Struct.AlignLog2 << "\n";
} else if (isTyped()) {
+ TypedInfo Typed = getTyped();
OS << " Element Type: " << getElementTypeName(Typed.ElementTy) << "\n"
<< " Element Count: " << Typed.ElementCount << "\n";
} else if (isFeedback())
- OS << " Feedback Type: " << getSamplerFeedbackTypeName(Feedback.Type)
+ OS << " Feedback Type: " << getSamplerFeedbackTypeName(getFeedbackType())
<< "\n";
}
}
//===----------------------------------------------------------------------===//
-// ResourceMapper
-static dxil::ElementType toDXILElementType(Type *Ty, bool IsSigned) {
- // TODO: Handle unorm, snorm, and packed.
- Ty = Ty->getScalarType();
+void DXILResourceMap::populate(Module &M) {
+ SmallVector<std::pair<CallInst *, ResourceInfo>> CIToInfo;
- if (Ty->isIntegerTy()) {
- switch (Ty->getIntegerBitWidth()) {
- case 16:
- return IsSigned ? ElementType::I16 : ElementType::U16;
- case 32:
- return IsSigned ? ElementType::I32 : ElementType::U32;
- case 64:
- return IsSigned ? ElementType::I64 : ElementType::U64;
- case 1:
+ for (Function &F : M.functions()) {
+ if (!F.isDeclaration())
+ continue;
+ LLVM_DEBUG(dbgs() << "Function: " << F.getName() << "\n");
+ Intrinsic::ID ID = F.getIntrinsicID();
+ switch (ID) {
default:
- return ElementType::Invalid;
- }
- } else if (Ty->isFloatTy()) {
- return ElementType::F32;
- } else if (Ty->isDoubleTy()) {
- return ElementType::F64;
- } else if (Ty->isHalfTy()) {
- return ElementType::F16;
- }
-
- return ElementType::Invalid;
-}
-
-namespace {
-
-class ResourceMapper {
- Module &M;
- LLVMContext &Context;
- SmallVector<std::pair<CallInst *, dxil::ResourceInfo>> Resources;
-
-public:
- ResourceMapper(Module &M) : M(M), Context(M.getContext()) {}
-
- void diagnoseHandle(CallInst *CI, const Twine &Msg,
- DiagnosticSeverity Severity = DS_Error) {
- std::string S;
- raw_string_ostream SS(S);
- CI->printAsOperand(SS);
- DiagnosticInfoUnsupported Diag(*CI->getFunction(), Msg + ": " + SS.str(),
- CI->getDebugLoc(), Severity);
- Context.diagnose(Diag);
- }
-
- ResourceInfo *mapBufferType(CallInst *CI, TargetExtType *HandleTy,
- bool IsTyped) {
- if (HandleTy->getNumTypeParameters() != 1 ||
- HandleTy->getNumIntParameters() != (IsTyped ? 3 : 2)) {
- diagnoseHandle(CI, Twine("Invalid buffer target type"));
- return nullptr;
- }
-
- Type *ElTy = HandleTy->getTypeParameter(0);
- unsigned IsWriteable = HandleTy->getIntParameter(0);
- unsigned IsROV = HandleTy->getIntParameter(1);
- bool IsSigned = IsTyped && HandleTy->getIntParameter(2);
-
- ResourceClass RC = IsWriteable ? ResourceClass::UAV : ResourceClass::SRV;
- ResourceKind Kind;
- if (IsTyped)
- Kind = ResourceKind::TypedBuffer;
- else if (ElTy->isIntegerTy(8))
- Kind = ResourceKind::RawBuffer;
- else
- Kind = ResourceKind::StructuredBuffer;
-
- // TODO: We need to lower to a typed pointer, can we smuggle the type
- // through?
- Value *Symbol = UndefValue::get(PointerType::getUnqual(Context));
- // TODO: We don't actually keep track of the name right now...
- StringRef Name = "";
-
- // Note that we return a pointer into the vector's storage. This is okay as
- // long as we don't add more elements until we're done with the pointer.
- auto &Pair =
- Resources.emplace_back(CI, ResourceInfo{RC, Kind, Symbol, Name});
- ResourceInfo *RI = &Pair.second;
-
- if (RI->isUAV())
- // TODO: We need analysis for GloballyCoherent and HasCounter
- RI->setUAV(false, false, IsROV);
-
- if (RI->isTyped()) {
- dxil::ElementType ET = toDXILElementType(ElTy, IsSigned);
- uint32_t Count = 1;
- if (auto *VTy = dyn_cast<FixedVectorType>(ElTy))
- Count = VTy->getNumElements();
- RI->setTyped(ET, Count);
- } else if (RI->isStruct()) {
- const DataLayout &DL = M.getDataLayout();
-
- // This mimics what DXC does. Notably, we only ever set the alignment if
- // the type is actually a struct type.
- uint32_t Stride = DL.getTypeAllocSize(ElTy);
- MaybeAlign Alignment;
- if (auto *STy = dyn_cast<StructType>(ElTy))
- Alignment = DL.getStructLayout(STy)->getAlignment();
- RI->setStruct(Stride, Alignment);
- }
-
- return RI;
- }
+ continue;
+ case Intrinsic::dx_handle_fromBinding: {
+ auto *HandleTy = cast<TargetExtType>(F.getReturnType());
- ResourceInfo *mapHandleIntrin(CallInst *CI) {
- FunctionType *FTy = CI->getFunctionType();
- Type *RetTy = FTy->getReturnType();
- auto *HandleTy = dyn_cast<TargetExtType>(RetTy);
- if (!HandleTy) {
- diagnoseHandle(CI, "dx.handle.fromBinding requires target type");
- return nullptr;
- }
-
- StringRef TypeName = HandleTy->getName();
- if (TypeName == "dx.TypedBuffer") {
- return mapBufferType(CI, HandleTy, /*IsTyped=*/true);
- } else if (TypeName == "dx.RawBuffer") {
- return mapBufferType(CI, HandleTy, /*IsTyped=*/false);
- } else if (TypeName == "dx.CBuffer") {
- // TODO: implement
- diagnoseHandle(CI, "dx.CBuffer handles are not implemented yet");
- return nullptr;
- } else if (TypeName == "dx.Sampler") {
- // TODO: implement
- diagnoseHandle(CI, "dx.Sampler handles are not implemented yet");
- return nullptr;
- } else if (TypeName == "dx.Texture") {
- // TODO: implement
- diagnoseHandle(CI, "dx.Texture handles are not implemented yet");
- return nullptr;
- }
-
- diagnoseHandle(CI, "Invalid target(dx) type");
- return nullptr;
- }
-
- ResourceInfo *mapHandleFromBinding(CallInst *CI) {
- assert(CI->getIntrinsicID() == Intrinsic::dx_handle_fromBinding &&
- "Must be dx.handle.fromBinding intrinsic");
-
- ResourceInfo *RI = mapHandleIntrin(CI);
- if (!RI)
- return nullptr;
-
- uint32_t Space = cast<ConstantInt>(CI->getArgOperand(0))->getZExtValue();
- uint32_t LowerBound =
- cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
- uint32_t Size = cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
-
- // We use a binding ID of zero for now - these will be filled in later.
- RI->bind(0U, Space, LowerBound, Size);
-
- return RI;
- }
-
- DXILResourceMap mapResources() {
- for (Function &F : M.functions()) {
- if (!F.isDeclaration())
- continue;
- LLVM_DEBUG(dbgs() << "Function: " << F.getName() << "\n");
- Intrinsic::ID ID = F.getIntrinsicID();
- switch (ID) {
- default:
- // TODO: handle `dx.op` functions.
- continue;
- case Intrinsic::dx_handle_fromBinding:
- for (User *U : F.users()) {
+ for (User *U : F.users())
+ if (CallInst *CI = dyn_cast<CallInst>(U)) {
LLVM_DEBUG(dbgs() << " Visiting: " << *U << "\n");
- if (CallInst *CI = dyn_cast<CallInst>(U))
- mapHandleFromBinding(CI);
+ uint32_t Space =
+ cast<ConstantInt>(CI->getArgOperand(0))->getZExtValue();
+ uint32_t LowerBound =
+ cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
+ uint32_t Size =
+ cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
+ ResourceInfo RI =
+ ResourceInfo{/*RecordID=*/0, Space, LowerBound, Size, HandleTy};
+
+ CIToInfo.emplace_back(CI, RI);
}
- break;
- }
- }
- return DXILResourceMap(std::move(Resources));
+ break;
+ }
+ }
}
-};
-
-} // namespace
-
-DXILResourceMap::DXILResourceMap(
- SmallVectorImpl<std::pair<CallInst *, dxil::ResourceInfo>> &&CIToRI) {
- if (CIToRI.empty())
- return;
- llvm::stable_sort(CIToRI, [](auto &LHS, auto &RHS) {
+ llvm::stable_sort(CIToInfo, [](auto &LHS, auto &RHS) {
// Sort by resource class first for grouping purposes, and then by the rest
// of the fields so that we can remove duplicates.
ResourceClass LRC = LHS.second.getResourceClass();
ResourceClass RRC = RHS.second.getResourceClass();
return std::tie(LRC, LHS.second) < std::tie(RRC, RHS.second);
});
- for (auto [CI, RI] : CIToRI) {
- if (Resources.empty() || RI != Resources.back())
- Resources.push_back(RI);
- CallMap[CI] = Resources.size() - 1;
+ for (auto [CI, RI] : CIToInfo) {
+ if (Infos.empty() || RI != Infos.back())
+ Infos.push_back(RI);
+ CallMap[CI] = Infos.size() - 1;
}
- unsigned Size = Resources.size();
+ unsigned Size = Infos.size();
// In DXC, Record ID is unique per resource type. Match that.
FirstUAV = FirstCBuffer = FirstSampler = Size;
uint32_t NextID = 0;
for (unsigned I = 0, E = Size; I != E; ++I) {
- ResourceInfo &RI = Resources[I];
+ ResourceInfo &RI = Infos[I];
if (RI.isUAV() && FirstUAV == Size) {
FirstUAV = I;
NextID = 0;
@@ -739,15 +637,14 @@ DXILResourceMap::DXILResourceMap(
}
// Adjust the resource binding to use the next ID.
- const ResourceInfo::ResourceBinding &Binding = RI.getBinding();
- RI.bind(NextID++, Binding.Space, Binding.LowerBound, Binding.Size);
+ RI.setBindingID(NextID++);
}
}
-void DXILResourceMap::print(raw_ostream &OS) const {
- for (unsigned I = 0, E = Resources.size(); I != E; ++I) {
+void DXILResourceMap::print(raw_ostream &OS, const DataLayout &DL) const {
+ for (unsigned I = 0, E = Infos.size(); I != E; ++I) {
OS << "Binding " << I << ":\n";
- Resources[I].print(OS);
+ Infos[I].print(OS, DL);
OS << "\n";
}
@@ -759,27 +656,24 @@ void DXILResourceMap::print(raw_ostream &OS) const {
}
//===----------------------------------------------------------------------===//
-// DXILResourceAnalysis and DXILResourcePrinterPass
-// Provide an explicit template instantiation for the static ID.
AnalysisKey DXILResourceAnalysis::Key;
DXILResourceMap DXILResourceAnalysis::run(Module &M,
ModuleAnalysisManager &AM) {
- DXILResourceMap Data = ResourceMapper(M).mapResources();
+ DXILResourceMap Data;
+ Data.populate(M);
return Data;
}
PreservedAnalyses DXILResourcePrinterPass::run(Module &M,
ModuleAnalysisManager &AM) {
- DXILResourceMap &DRM = AM.getResult<DXILResourceAnalysis>(M);
- DRM.print(OS);
+ DXILResourceMap &DBM = AM.getResult<DXILResourceAnalysis>(M);
+
+ DBM.print(OS, M.getDataLayout());
return PreservedAnalyses::all();
}
-//===----------------------------------------------------------------------===//
-// DXILResourceWrapperPass
-
DXILResourceWrapperPass::DXILResourceWrapperPass() : ModulePass(ID) {
initializeDXILResourceWrapperPassPass(*PassRegistry::getPassRegistry());
}
@@ -791,18 +685,21 @@ void DXILResourceWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
}
bool DXILResourceWrapperPass::runOnModule(Module &M) {
- ResourceMap.reset(new DXILResourceMap(ResourceMapper(M).mapResources()));
+ Map.reset(new DXILResourceMap());
+
+ Map->populate(M);
+
return false;
}
-void DXILResourceWrapperPass::releaseMemory() { ResourceMap.reset(); }
+void DXILResourceWrapperPass::releaseMemory() { Map.reset(); }
-void DXILResourceWrapperPass::print(raw_ostream &OS, const Module *) const {
-  if (!ResourceMap) {
+/// Prints the resource map built by runOnModule(), or a placeholder message
+/// if no map has been built yet. \p M may be null.
+void DXILResourceWrapperPass::print(raw_ostream &OS, const Module *M) const {
+  if (!Map) {
     OS << "No resource map has been built!\n";
     return;
   }
-  ResourceMap->print(OS);
+  // dump() calls print() with a null module, so fall back to a default data
+  // layout instead of dereferencing a null pointer. Sizes printed through the
+  // fallback may differ from those computed with the module's real layout.
+  const DataLayout DefaultDL;
+  const DataLayout &DL = M ? M->getDataLayout() : DefaultDL;
+  Map->print(OS, DL);
 }
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -810,8 +707,8 @@ LLVM_DUMP_METHOD
void DXILResourceWrapperPass::dump() const { print(dbgs(), nullptr); }
#endif
-INITIALIZE_PASS(DXILResourceWrapperPass, DEBUG_TYPE, "DXIL Resource analysis",
- false, true)
+INITIALIZE_PASS(DXILResourceWrapperPass, "dxil-resource-binding",
+ "DXIL Resource analysis", false, true)
char DXILResourceWrapperPass::ID = 0;
ModulePass *llvm::createDXILResourceWrapperPassPass() {
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index d9e70da6ed653a..78efdcf194b6c6 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -304,7 +304,7 @@ class OpLowerer {
IndexOp = IRB.CreateAdd(IndexOp,
ConstantInt::get(Int32Ty, Binding.LowerBound));
- std::pair<uint32_t, uint32_t> Props = RI.getAnnotateProps();
+ std::pair<uint32_t, uint32_t> Props = RI.getAnnotateProps(*F.getParent());
// For `CreateHandleFromBinding` we need the upper bound rather than the
// size, so we need to be careful about the difference for "unbounded".
diff --git a/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp b/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp
index 0478dc2df988de..4aa25b3996e3c1 100644
--- a/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp
+++ b/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp
@@ -149,7 +149,7 @@ struct FormatResourceDimension
default: {
OS << getTextureDimName(RK);
if (Item.isMultiSample())
- OS << Item.getMultiSample().Count;
+ OS << Item.getMultiSampleCount();
break;
}
case dxil::ResourceKind::RawBuffer:
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 4ba10d123e8d27..baefadede6e3ab 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -78,13 +78,13 @@ static NamedMDNode *emitResourceMetadata(Module &M, const DXILResourceMap &DRM,
SmallVector<Metadata *> SRVs, UAVs, CBufs, Smps;
for (const ResourceInfo &RI : DRM.srvs())
- SRVs.push_back(RI.getAsMetadata(Context));
+ SRVs.push_back(RI.getAsMetadata(M));
for (const ResourceInfo &RI : DRM.uavs())
- UAVs.push_back(RI.getAsMetadata(Context));
+ UAVs.push_back(RI.getAsMetadata(M));
for (const ResourceInfo &RI : DRM.cbuffers())
- CBufs.push_back(RI.getAsMetadata(Context));
+ CBufs.push_back(RI.getAsMetadata(M));
for (const ResourceInfo &RI : DRM.samplers())
- Smps.push_back(RI.getAsMetadata(Context));
+ Smps.push_back(RI.getAsMetadata(M));
Metadata *SRVMD = SRVs.empty() ? nullptr : MDNode::get(Context, SRVs);
Metadata *UAVMD = UAVs.empty() ? nullptr : MDNode::get(Context, UAVs);
diff --git a/llvm/unittests/Analysis/DXILResourceTest.cpp b/llvm/unittests/Analysis/DXILResourceTest.cpp
index e24018457dabec..2122f1a91cfc96 100644
--- a/llvm/unittests/Analysis/DXILResourceTest.cpp
+++ b/llvm/unittests/Analysis/DXILResourceTest.cpp
@@ -8,6 +8,9 @@
#include "llvm/Analysis/DXILResource.h"
#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Module.h"
#include "gtest/gtest.h"
using namespace llvm;
@@ -99,8 +102,16 @@ testing::AssertionResult MDTupleEq(const char *LHSExpr, const char *RHSExpr,
} // namespace
TEST(DXILResource, AnnotationsAndMetadata) {
+ // TODO: How am I supposed to get this?
+ DataLayout DL("e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-"
+ "f64:64-n8:16:32:64-v96:32");
+
LLVMContext Context;
+ Module M("AnnotationsAndMetadata", Context);
+ M.setDataLayout(DL);
+
Type *Int1Ty = Type::getInt1Ty(Context);
+ Type *Int8Ty = Type::getInt8Ty(Context);
Type *Int32Ty = Type::getInt32Ty(Context);
Type *FloatTy = Type::getFloatTy(Context);
Type *DoubleTy = Type::getDoubleTy(Context);
@@ -109,206 +120,310 @@ TEST(DXILResource, AnnotationsAndMetadata) {
Type *Int32x2Ty = FixedVectorType::get(Int32Ty, 2);
MDBuilder TestMD(Context, Int32Ty, Int1Ty);
+ Value *DummyGV = UndefValue::get(PointerType::getUnqual(Context));
+
+ // ByteAddressBuffer Buffer;
+ auto *HandleTy = llvm::TargetExtType::get(Context, "dx.RawBuffer", Int8Ty,
+ {/*IsWriteable=*/0, /*IsROV=*/0});
+ ResourceInfo RI(
+ /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, HandleTy);
- // ByteAddressBuffer Buffer0;
- Value *Symbol = UndefValue::get(
- StructType::create(Context, {Int32Ty}, "struct.ByteAddressBuffer"));
- ResourceInfo Resource = ResourceInfo::RawBuffer(Symbol, "Buffer0");
- Resource.bind(0, 0, 0, 1);
- std::pair<uint32_t, uint32_t> Props = Resource.getAnnotateProps();
+ EXPECT_EQ(RI.getResourceClass(), ResourceClass::SRV);
+ EXPECT_EQ(RI.getResourceKind(), ResourceKind::RawBuffer);
+
+ std::pair<uint32_t, uint32_t> Props = RI.getAnnotateProps(M);
EXPECT_EQ(Props.first, 0x0000000bU);
EXPECT_EQ(Props.second, 0U);
- MDTuple *MD = Resource.getAsMetadata(Context);
- EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "Buffer0", 0, 0, 1, 11, 0, nullptr));
+ MDTuple *MD = RI.getAsMetadata(M);
+ EXPECT_MDEQ(MD, TestMD.get(0, DummyGV, "", 0, 0, 1, 11, 0, nullptr));
// RWByteAddressBuffer BufferOut : register(u3, space2);
- Symbol = UndefValue::get(
- StructType::create(Context, {Int32Ty}, "struct.RWByteAddressBuffer"));
- Resource =
- ResourceInfo::RWRawBuffer(Symbol, "BufferOut",
- /*GloballyCoherent=*/false, /*IsROV=*/false);
- Resource.bind(1, 2, 3, 1);
- Props = Resource.getAnnotateProps();
+ HandleTy = llvm::TargetExtType::get(Context, "dx.RawBuffer", Int8Ty,
+ {/*IsWriteable=*/1, /*IsROV=*/0});
+ RI = ResourceInfo(
+ /*RecordID=*/1, /*Space=*/2, /*LowerBound=*/3, /*Size=*/1, HandleTy);
+
+ EXPECT_EQ(RI.getResourceClass(), ResourceClass::UAV);
+ EXPECT_EQ(RI.getUAV().GloballyCoherent, false);
+ EXPECT_EQ(RI.getUAV().HasCounter, false);
+ EXPECT_EQ(RI.getUAV().IsROV, false);
+ EXPECT_EQ(RI.getResourceKind(), ResourceKind::RawBuffer);
+
+ Props = RI.getAnnotateProps(M);
EXPECT_EQ(Props.first, 0x0000100bU);
EXPECT_EQ(Props.second, 0U);
- MD = Resource.getAsMetadata(Context);
- EXPECT_MDEQ(MD, TestMD.get(1, Symbol, "BufferOut", 2, 3, 1, 11, false, false,
- false, nullptr));
+ MD = RI.getAsMetadata(M);
+ EXPECT_MDEQ(MD, TestMD.get(1, DummyGV, "", 2, 3, 1, 11, false, false, false,
+ nullptr));
// struct BufType0 { int i; float f; double d; };
// StructuredBuffer<BufType0> Buffer0 : register(t0);
StructType *BufType0 =
StructType::create(Context, {Int32Ty, FloatTy, DoubleTy}, "BufType0");
- Symbol = UndefValue::get(StructType::create(
- Context, {BufType0}, "class.StructuredBuffer<BufType>"));
- Resource = ResourceInfo::StructuredBuffer(Symbol, "Buffer0",
- /*Stride=*/16, Align(8));
- Resource.bind(0, 0, 0, 1);
- Props = Resource.getAnnotateProps();
+ HandleTy = llvm::TargetExtType::get(Context, "dx.RawBuffer", BufType0,
+ {/*IsWriteable=*/0, /*IsROV=*/0});
+ RI = ResourceInfo(
+ /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, HandleTy);
+
+ EXPECT_EQ(RI.getResourceClass(), ResourceClass::SRV);
+ ASSERT_EQ(RI.isStruct(), true);
+ EXPECT_EQ(RI.getStruct(DL).Stride, 16u);
+ EXPECT_EQ(RI.getStruct(DL).AlignLog2, Log2(Align(8)));
+ EXPECT_EQ(RI.getResourceKind(), ResourceKind::StructuredBuffer);
+
+ Props = RI.getAnnotateProps(M);
EXPECT_EQ(Props.first, 0x0000030cU);
EXPECT_EQ(Props.second, 0x00000010U);
- MD = Resource.getAsMetadata(Context);
- EXPECT_MDEQ(
- MD, TestMD.get(0, Symbol, "Buffer0", 0, 0, 1, 12, 0, TestMD.get(1, 16)));
+ MD = RI.getAsMetadata(M);
+ EXPECT_MDEQ(MD,
+ TestMD.get(0, DummyGV, "", 0, 0, 1, 12, 0, TestMD.get(1, 16)));
// StructuredBuffer<float3> Buffer1 : register(t1);
- Symbol = UndefValue::get(StructType::create(
- Context, {Floatx3Ty}, "class.StructuredBuffer<vector<float, 3> >"));
- Resource = ResourceInfo::StructuredBuffer(Symbol, "Buffer1",
- /*Stride=*/12, {});
- Resource.bind(1, 0, 1, 1);
- Props = Resource.getAnnotateProps();
+ HandleTy = llvm::TargetExtType::get(Context, "dx.RawBuffer", Floatx3Ty,
+ {/*IsWriteable=*/0, /*IsROV=*/0});
+ RI = ResourceInfo(
+ /*RecordID=*/1, /*Space=*/0, /*LowerBound=*/1, /*Size=*/1, HandleTy);
+
+ EXPECT_EQ(RI.getResourceClass(), ResourceClass::SRV);
+ ASSERT_EQ(RI.isStruct(), true);
+ EXPECT_EQ(RI.getStruct(DL).Stride, 12u);
+ EXPECT_EQ(RI.getStruct(DL).AlignLog2, 0u);
+ EXPECT_EQ(RI.getResourceKind(), ResourceKind::StructuredBuffer);
+
+ Props = RI.getAnnotateProps(M);
EXPECT_EQ(Props.first, 0x0000000cU);
EXPECT_EQ(Props.second, 0x0000000cU);
- MD = Resource.getAsMetadata(Context);
- EXPECT_MDEQ(
- MD, TestMD.get(1, Symbol, "Buffer1", 0, 1, 1, 12, 0, TestMD.get(1, 12)));
+ MD = RI.getAsMetadata(M);
+ EXPECT_MDEQ(MD,
+ TestMD.get(1, DummyGV, "", 0, 1, 1, 12, 0, TestMD.get(1, 12)));
// Texture2D<float4> ColorMapTexture : register(t2);
- Symbol = UndefValue::get(StructType::create(
- Context, {Floatx4Ty}, "class.Texture2D<vector<float, 4> >"));
- Resource =
- ResourceInfo::SRV(Symbol, "ColorMapTexture", dxil::ElementType::F32,
- /*ElementCount=*/4, dxil::ResourceKind::Texture2D);
- Resource.bind(2, 0, 2, 1);
- Props = Resource.getAnnotateProps();
+ HandleTy =
+ llvm::TargetExtType::get(Context, "dx.Texture", Floatx4Ty,
+ {/*IsWriteable=*/0, /*IsROV=*/0, /*IsSigned=*/0,
+ llvm::to_underlying(ResourceKind::Texture2D)});
+ RI = ResourceInfo(
+ /*RecordID=*/2, /*Space=*/0, /*LowerBound=*/2, /*Size=*/1, HandleTy);
+
+ EXPECT_EQ(RI.getResourceClass(), ResourceClass::SRV);
+ ASSERT_EQ(RI.isTyped(), true);
+ EXPECT_EQ(RI.getTyped().ElementTy, ElementType::F32);
+ EXPECT_EQ(RI.getTyped().ElementCount, 4u);
+ EXPECT_EQ(RI.getResourceKind(), ResourceKind::Texture2D);
+
+ Props = RI.getAnnotateProps(M);
EXPECT_EQ(Props.first, 0x00000002U);
EXPECT_EQ(Props.second, 0x00000409U);
- MD = Resource.getAsMetadata(Context);
- EXPECT_MDEQ(MD, TestMD.get(2, Symbol, "ColorMapTexture", 0, 2, 1, 2, 0,
- TestMD.get(0, 9)));
+ MD = RI.getAsMetadata(M);
+ EXPECT_MDEQ(MD, TestMD.get(2, DummyGV, "", 0, 2, 1, 2, 0, TestMD.get(0, 9)));
// Texture2DMS<float, 8> DepthBuffer : register(t0);
- Symbol = UndefValue::get(
- StructType::create(Context, {FloatTy}, "class.Texture2DMS<float, 8>"));
- Resource =
- ResourceInfo::Texture2DMS(Symbol, "DepthBuffer", dxil::ElementType::F32,
- /*ElementCount=*/1, /*SampleCount=*/8);
- Resource.bind(0, 0, 0, 1);
- Props = Resource.getAnnotateProps();
+ HandleTy = llvm::TargetExtType::get(
+ Context, "dx.MSTexture", FloatTy,
+ {/*IsWriteable=*/0, /*SampleCount=*/8,
+ /*IsSigned=*/0, llvm::to_underlying(ResourceKind::Texture2DMS)});
+ RI = ResourceInfo(
+ /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, HandleTy);
+
+ EXPECT_EQ(RI.getResourceClass(), ResourceClass::SRV);
+ ASSERT_EQ(RI.isTyped(), true);
+ EXPECT_EQ(RI.getTyped().ElementTy, ElementType::F32);
+ EXPECT_EQ(RI.getTyped().ElementCount, 1u);
+ ASSERT_EQ(RI.isMultiSample(), true);
+ EXPECT_EQ(RI.getMultiSampleCount(), 8u);
+ EXPECT_EQ(RI.getResourceKind(), ResourceKind::Texture2DMS);
+
+ Props = RI.getAnnotateProps(M);
EXPECT_EQ(Props.first, 0x00000003U);
EXPECT_EQ(Props.second, 0x00080109U);
- MD = Resource.getAsMetadata(Context);
- EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "DepthBuffer", 0, 0, 1, 3, 8,
- TestMD.get(0, 9)));
+ MD = RI.getAsMetadata(M);
+ EXPECT_MDEQ(MD, TestMD.get(0, DummyGV, "", 0, 0, 1, 3, 8, TestMD.get(0, 9)));
// FeedbackTexture2D<SAMPLER_FEEDBACK_MIN_MIP> feedbackMinMip;
- Symbol = UndefValue::get(
- StructType::create(Context, {Int32Ty}, "class.FeedbackTexture2D<0>"));
- Resource = ResourceInfo::FeedbackTexture2D(Symbol, "feedbackMinMip",
- SamplerFeedbackType::MinMip);
- Resource.bind(0, 0, 0, 1);
- Props = Resource.getAnnotateProps();
+ HandleTy = llvm::TargetExtType::get(
+ Context, "dx.FeedbackTexture", {},
+ {llvm::to_underlying(SamplerFeedbackType::MinMip),
+ llvm::to_underlying(ResourceKind::FeedbackTexture2D)});
+ RI = ResourceInfo(
+ /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, HandleTy);
+
+ EXPECT_EQ(RI.getResourceClass(), ResourceClass::UAV);
+ ASSERT_EQ(RI.isFeedback(), true);
+ EXPECT_EQ(RI.getFeedbackType(), SamplerFeedbackType::MinMip);
+ EXPECT_EQ(RI.getResourceKind(), ResourceKind::FeedbackTexture2D);
+
+ Props = RI.getAnnotateProps(M);
EXPECT_EQ(Props.first, 0x00001011U);
EXPECT_EQ(Props.second, 0U);
- MD = Resource.getAsMetadata(Context);
- EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "feedbackMinMip", 0, 0, 1, 17, false,
- false, false, TestMD.get(2, 0)));
+ MD = RI.getAsMetadata(M);
+ EXPECT_MDEQ(MD, TestMD.get(0, DummyGV, "", 0, 0, 1, 17, false, false, false,
+ TestMD.get(2, 0)));
// FeedbackTexture2DArray<SAMPLER_FEEDBACK_MIP_REGION_USED> feedbackMipRegion;
- Symbol = UndefValue::get(StructType::create(
- Context, {Int32Ty}, "class.FeedbackTexture2DArray<1>"));
- Resource = ResourceInfo::FeedbackTexture2DArray(
- Symbol, "feedbackMipRegion", SamplerFeedbackType::MipRegionUsed);
- Resource.bind(0, 0, 0, 1);
- Props = Resource.getAnnotateProps();
+ HandleTy = llvm::TargetExtType::get(
+ Context, "dx.FeedbackTexture", {},
+ {llvm::to_underlying(SamplerFeedbackType::MipRegionUsed),
+ llvm::to_underlying(ResourceKind::FeedbackTexture2DArray)});
+ RI = ResourceInfo(
+ /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, HandleTy);
+
+ EXPECT_EQ(RI.getResourceClass(), ResourceClass::UAV);
+ ASSERT_EQ(RI.isFeedback(), true);
+ EXPECT_EQ(RI.getFeedbackType(), SamplerFeedbackType::MipRegionUsed);
+ EXPECT_EQ(RI.getResourceKind(), ResourceKind::FeedbackTexture2DArray);
+
+ Props = RI.getAnnotateProps(M);
EXPECT_EQ(Props.first, 0x00001012U);
EXPECT_EQ(Props.second, 0x00000001U);
- MD = Resource.getAsMetadata(Context);
- EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "feedbackMipRegion", 0, 0, 1, 18, false,
- false, false, TestMD.get(2, 1)));
+ MD = RI.getAsMetadata(M);
+ EXPECT_MDEQ(MD, TestMD.get(0, DummyGV, "", 0, 0, 1, 18, false, false, false,
+ TestMD.get(2, 1)));
// globallycoherent RWTexture2D<int2> OutputTexture : register(u0, space2);
- Symbol = UndefValue::get(StructType::create(
- Context, {Int32x2Ty}, "class.RWTexture2D<vector<int, 2> >"));
- Resource = ResourceInfo::UAV(Symbol, "OutputTexture", dxil::ElementType::I32,
- /*ElementCount=*/2, /*GloballyCoherent=*/1,
- /*IsROV=*/0, dxil::ResourceKind::Texture2D);
- Resource.bind(0, 2, 0, 1);
- Props = Resource.getAnnotateProps();
+ HandleTy =
+ llvm::TargetExtType::get(Context, "dx.Texture", Int32x2Ty,
+ {/*IsWriteable=*/1,
+ /*IsROV=*/0, /*IsSigned=*/1,
+ llvm::to_underlying(ResourceKind::Texture2D)});
+
+ RI = ResourceInfo(
+ /*RecordID=*/0, /*Space=*/2, /*LowerBound=*/0, /*Size=*/1, HandleTy,
+ /*GloballyCoherent=*/true, /*HasCounter=*/false);
+
+ EXPECT_EQ(RI.getResourceClass(), ResourceClass::UAV);
+ EXPECT_EQ(RI.getUAV().GloballyCoherent, true);
+ EXPECT_EQ(RI.getUAV().HasCounter, false);
+ EXPECT_EQ(RI.getUAV().IsROV, false);
+ EXPECT_EQ(RI.getResourceKind(), ResourceKind::Texture2D);
+
+ Props = RI.getAnnotateProps(M);
EXPECT_EQ(Props.first, 0x00005002U);
EXPECT_EQ(Props.second, 0x00000204U);
- MD = Resource.getAsMetadata(Context);
- EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "OutputTexture", 2, 0, 1, 2, true,
- false, false, TestMD.get(0, 4)));
+ MD = RI.getAsMetadata(M);
+ EXPECT_MDEQ(MD, TestMD.get(0, DummyGV, "", 2, 0, 1, 2, true, false, false,
+ TestMD.get(0, 4)));
// RasterizerOrderedBuffer<float4> ROB;
- Symbol = UndefValue::get(
- StructType::create(Context, {Floatx4Ty},
- "class.RasterizerOrderedBuffer<vector<float, 4> >"));
- Resource = ResourceInfo::UAV(Symbol, "ROB", dxil::ElementType::F32,
- /*ElementCount=*/4, /*GloballyCoherent=*/0,
- /*IsROV=*/1, dxil::ResourceKind::TypedBuffer);
- Resource.bind(0, 0, 0, 1);
- Props = Resource.getAnnotateProps();
+ HandleTy = llvm::TargetExtType::get(
+ Context, "dx.TypedBuffer", Floatx4Ty,
+ {/*IsWriteable=*/1, /*IsROV=*/1, /*IsSigned=*/0});
+ RI = ResourceInfo(
+ /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, HandleTy);
+
+ EXPECT_EQ(RI.getResourceClass(), ResourceClass::UAV);
+ EXPECT_EQ(RI.getUAV().GloballyCoherent, false);
+ EXPECT_EQ(RI.getUAV().HasCounter, false);
+ EXPECT_EQ(RI.getUAV().IsROV, true);
+ ASSERT_EQ(RI.isTyped(), true);
+ EXPECT_EQ(RI.getTyped().ElementTy, ElementType::F32);
+ EXPECT_EQ(RI.getTyped().ElementCount, 4u);
+ EXPECT_EQ(RI.getResourceKind(), ResourceKind::TypedBuffer);
+
+ Props = RI.getAnnotateProps(M);
EXPECT_EQ(Props.first, 0x0000300aU);
EXPECT_EQ(Props.second, 0x00000409U);
- MD = Resource.getAsMetadata(Context);
- EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "ROB", 0, 0, 1, 10, false, false, true,
+ MD = RI.getAsMetadata(M);
+ EXPECT_MDEQ(MD, TestMD.get(0, DummyGV, "", 0, 0, 1, 10, false, false, true,
TestMD.get(0, 9)));
// RWStructuredBuffer<ParticleMotion> g_OutputBuffer : register(u2);
StructType *BufType1 = StructType::create(
Context, {Floatx3Ty, FloatTy, Int32Ty}, "ParticleMotion");
- Symbol = UndefValue::get(StructType::create(
- Context, {BufType1}, "class.StructuredBuffer<ParticleMotion>"));
- Resource =
- ResourceInfo::RWStructuredBuffer(Symbol, "g_OutputBuffer", /*Stride=*/20,
- Align(4), /*GloballyCoherent=*/false,
- /*IsROV=*/false, /*HasCounter=*/true);
- Resource.bind(0, 0, 2, 1);
- Props = Resource.getAnnotateProps();
+ HandleTy = llvm::TargetExtType::get(Context, "dx.RawBuffer", BufType1,
+ {/*IsWriteable=*/1, /*IsROV=*/0});
+ RI = ResourceInfo(
+ /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/2, /*Size=*/1, HandleTy,
+ /*GloballyCoherent=*/false, /*HasCounter=*/true);
+
+ EXPECT_EQ(RI.getResourceClass(), ResourceClass::UAV);
+ EXPECT_EQ(RI.getUAV().GloballyCoherent, false);
+ EXPECT_EQ(RI.getUAV().HasCounter, true);
+ EXPECT_EQ(RI.getUAV().IsROV, false);
+ ASSERT_EQ(RI.isStruct(), true);
+ EXPECT_EQ(RI.getStruct(DL).Stride, 20u);
+ EXPECT_EQ(RI.getStruct(DL).AlignLog2, Log2(Align(4)));
+ EXPECT_EQ(RI.getResourceKind(), ResourceKind::StructuredBuffer);
+
+ Props = RI.getAnnotateProps(M);
EXPECT_EQ(Props.first, 0x0000920cU);
EXPECT_EQ(Props.second, 0x00000014U);
- MD = Resource.getAsMetadata(Context);
- EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "g_OutputBuffer", 0, 2, 1, 12, false,
- true, false, TestMD.get(1, 20)));
-
- // RWTexture2DMSArray<uint,8> g_rw_t2dmsa;
- Symbol = UndefValue::get(StructType::create(
- Context, {Int32Ty}, "class.RWTexture2DMSArray<unsigned int, 8>"));
- Resource = ResourceInfo::RWTexture2DMSArray(
- Symbol, "g_rw_t2dmsa", dxil::ElementType::U32, /*ElementCount=*/1,
- /*SampleCount=*/8, /*GloballyCoherent=*/false);
- Resource.bind(0, 0, 0, 1);
- Props = Resource.getAnnotateProps();
+ MD = RI.getAsMetadata(M);
+ EXPECT_MDEQ(MD, TestMD.get(0, DummyGV, "", 0, 2, 1, 12, false, true, false,
+ TestMD.get(1, 20)));
+
+ // RWTexture2DMSArray<uint, 8> g_rw_t2dmsa;
+ HandleTy = llvm::TargetExtType::get(
+ Context, "dx.MSTexture", Int32Ty,
+ {/*IsWriteable=*/1, /*SampleCount=*/8, /*IsSigned=*/0,
+ llvm::to_underlying(ResourceKind::Texture2DMSArray)});
+ RI = ResourceInfo(
+ /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, HandleTy);
+
+ EXPECT_EQ(RI.getResourceClass(), ResourceClass::UAV);
+ EXPECT_EQ(RI.getUAV().GloballyCoherent, false);
+ EXPECT_EQ(RI.getUAV().HasCounter, false);
+ EXPECT_EQ(RI.getUAV().IsROV, false);
+ ASSERT_EQ(RI.isTyped(), true);
+ EXPECT_EQ(RI.getTyped().ElementTy, ElementType::U32);
+ EXPECT_EQ(RI.getTyped().ElementCount, 1u);
+ ASSERT_EQ(RI.isMultiSample(), true);
+ EXPECT_EQ(RI.getMultiSampleCount(), 8u);
+ EXPECT_EQ(RI.getResourceKind(), ResourceKind::Texture2DMSArray);
+
+ Props = RI.getAnnotateProps(M);
EXPECT_EQ(Props.first, 0x00001008U);
EXPECT_EQ(Props.second, 0x00080105U);
- MD = Resource.getAsMetadata(Context);
- EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "g_rw_t2dmsa", 0, 0, 1, 8, false, false,
- false, TestMD.get(0, 5)));
+ MD = RI.getAsMetadata(M);
+ EXPECT_MDEQ(MD, TestMD.get(0, DummyGV, "", 0, 0, 1, 8, false, false, false,
+ TestMD.get(0, 5)));
// cbuffer cb0 { float4 g_X; float4 g_Y; }
- Symbol = UndefValue::get(
- StructType::create(Context, {Floatx4Ty, Floatx4Ty}, "cb0"));
- Resource = ResourceInfo::CBuffer(Symbol, "cb0", /*Size=*/32);
- Resource.bind(0, 0, 0, 1);
- Props = Resource.getAnnotateProps();
+ StructType *CBufType0 =
+ StructType::create(Context, {Floatx4Ty, Floatx4Ty}, "cb0");
+ HandleTy = llvm::TargetExtType::get(Context, "dx.CBuffer", CBufType0, {});
+ RI = ResourceInfo(
+ /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, HandleTy);
+
+ EXPECT_EQ(RI.getResourceClass(), ResourceClass::CBuffer);
+ EXPECT_EQ(RI.getCBufferSize(DL), 32u);
+ EXPECT_EQ(RI.getResourceKind(), ResourceKind::CBuffer);
+
+ Props = RI.getAnnotateProps(M);
EXPECT_EQ(Props.first, 0x0000000dU);
EXPECT_EQ(Props.second, 0x00000020U);
- MD = Resource.getAsMetadata(Context);
- EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "cb0", 0, 0, 1, 32, nullptr));
+ MD = RI.getAsMetadata(M);
+ EXPECT_MDEQ(MD, TestMD.get(0, DummyGV, "", 0, 0, 1, 32, nullptr));
// SamplerState ColorMapSampler : register(s0);
- Symbol = UndefValue::get(
- StructType::create(Context, {Int32Ty}, "struct.SamplerState"));
- Resource = ResourceInfo::Sampler(Symbol, "ColorMapSampler",
- dxil::SamplerType::Default);
- Resource.bind(0, 0, 0, 1);
- Props = Resource.getAnnotateProps();
+ HandleTy = llvm::TargetExtType::get(
+ Context, "dx.Sampler", {},
+ {llvm::to_underlying(dxil::SamplerType::Default)});
+ RI = ResourceInfo(
+ /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, HandleTy);
+
+ EXPECT_EQ(RI.getResourceClass(), ResourceClass::Sampler);
+ EXPECT_EQ(RI.getSamplerType(), dxil::SamplerType::Default);
+ EXPECT_EQ(RI.getResourceKind(), ResourceKind::Sampler);
+
+ Props = RI.getAnnotateProps(M);
EXPECT_EQ(Props.first, 0x0000000eU);
EXPECT_EQ(Props.second, 0U);
- MD = Resource.getAsMetadata(Context);
- EXPECT_MDEQ(MD,
- TestMD.get(0, Symbol, "ColorMapSampler", 0, 0, 1, 0, nullptr));
+ MD = RI.getAsMetadata(M);
+ EXPECT_MDEQ(MD, TestMD.get(0, DummyGV, "", 0, 0, 1, 0, nullptr));
+
+ HandleTy = llvm::TargetExtType::get(
+ Context, "dx.Sampler", {},
+ {llvm::to_underlying(dxil::SamplerType::Comparison)});
+ RI = ResourceInfo(
+ /*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, HandleTy);
+
+ EXPECT_EQ(RI.getResourceClass(), ResourceClass::Sampler);
+ EXPECT_EQ(RI.getSamplerType(), dxil::SamplerType::Comparison);
+ EXPECT_EQ(RI.getResourceKind(), ResourceKind::Sampler);
- // SamplerComparisonState ShadowSampler {...};
- Resource = ResourceInfo::Sampler(Symbol, "CmpSampler",
- dxil::SamplerType::Comparison);
- Resource.bind(0, 0, 0, 1);
- Props = Resource.getAnnotateProps();
+ Props = RI.getAnnotateProps(M);
EXPECT_EQ(Props.first, 0x0000800eU);
EXPECT_EQ(Props.second, 0U);
- MD = Resource.getAsMetadata(Context);
- EXPECT_MDEQ(MD, TestMD.get(0, Symbol, "CmpSampler", 0, 0, 1, 1, nullptr));
+ MD = RI.getAsMetadata(M);
+ EXPECT_MDEQ(MD, TestMD.get(0, DummyGV, "", 0, 0, 1, 1, nullptr));
}
>From 90c5fb1eee6532ad9368210193fbb9d8601f912c Mon Sep 17 00:00:00 2001
From: Justin Bogner <mail at justinbogner.com>
Date: Thu, 12 Dec 2024 14:58:29 -0800
Subject: [PATCH 2/3] clang-format
---
llvm/lib/Analysis/DXILResource.cpp | 17 +++++------------
1 file changed, 5 insertions(+), 12 deletions(-)
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index f96a9468d6bc54..276f315bbf643b 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -211,13 +211,9 @@ ResourceInfo::ResourceInfo(uint32_t RecordID, uint32_t Space,
bool ResourceInfo::isUAV() const { return RC == ResourceClass::UAV; }
-bool ResourceInfo::isCBuffer() const {
- return RC == ResourceClass::CBuffer;
-}
+bool ResourceInfo::isCBuffer() const { return RC == ResourceClass::CBuffer; }
-bool ResourceInfo::isSampler() const {
- return RC == ResourceClass::Sampler;
-}
+bool ResourceInfo::isSampler() const { return RC == ResourceClass::Sampler; }
bool ResourceInfo::isStruct() const {
return Kind == ResourceKind::StructuredBuffer;
@@ -309,8 +305,7 @@ dxil::SamplerType ResourceInfo::getSamplerType() const {
return cast<SamplerExtType>(HandleTy)->getSamplerType();
}
-ResourceInfo::StructInfo
-ResourceInfo::getStruct(const DataLayout &DL) const {
+ResourceInfo::StructInfo ResourceInfo::getStruct(const DataLayout &DL) const {
assert(isStruct() && "Not a Struct");
Type *ElTy = cast<RawBufferExtType>(HandleTy)->getResourceType();
@@ -428,8 +423,7 @@ MDTuple *ResourceInfo::getAsMetadata(Module &M) const {
// All SRVs include sample count in the metadata, but it's only meaningful
// for multi-sampled textures. Also, UAVs can be multisampled in SM6.7+,
// but this just isn't reflected in the metadata at all.
- uint32_t SampleCount =
- isMultiSample() ? getMultiSampleCount() : 0;
+ uint32_t SampleCount = isMultiSample() ? getMultiSampleCount() : 0;
MDVals.push_back(getIntMD(SampleCount));
}
@@ -459,8 +453,7 @@ std::pair<uint32_t, uint32_t> ResourceInfo::getAnnotateProps(Module &M) const {
uint32_t ResourceKind = llvm::to_underlying(getResourceKind());
uint32_t AlignLog2 = isStruct() ? getStruct(DL).AlignLog2 : 0;
bool IsUAV = isUAV();
- ResourceInfo::UAVInfo UAVFlags =
- IsUAV ? getUAV() : ResourceInfo::UAVInfo{};
+ ResourceInfo::UAVInfo UAVFlags = IsUAV ? getUAV() : ResourceInfo::UAVInfo{};
bool IsROV = IsUAV && UAVFlags.IsROV;
bool IsGloballyCoherent = IsUAV && UAVFlags.GloballyCoherent;
uint8_t SamplerCmpOrHasCounter = 0;
>From 6333826bc3a9118b26aaf05f8288347b14c30544 Mon Sep 17 00:00:00 2001
From: Justin Bogner <mail at justinbogner.com>
Date: Mon, 16 Dec 2024 11:28:40 -0700
Subject: [PATCH 3/3] fixup: Put cbuffer size in param
---
llvm/include/llvm/Analysis/DXILResource.h | 1 +
llvm/lib/Analysis/DXILResource.cpp | 3 +--
llvm/unittests/Analysis/DXILResourceTest.cpp | 3 ++-
3 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 0205356af54443..d40bc02d7eec96 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -164,6 +164,7 @@ class CBufferExtType : public TargetExtType {
CBufferExtType &operator=(const CBufferExtType &) = delete;
Type *getResourceType() const { return getTypeParameter(0); }
+ uint32_t getCBufferSize() const { return getIntParameter(0); }
static bool classof(const TargetExtType *T) {
return T->getName() == "dx.CBuffer";
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index 276f315bbf643b..3fa9d67488b0cc 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -296,8 +296,7 @@ ResourceInfo::UAVInfo ResourceInfo::getUAV() const {
uint32_t ResourceInfo::getCBufferSize(const DataLayout &DL) const {
assert(isCBuffer() && "Not a CBuffer");
- Type *Ty = cast<CBufferExtType>(HandleTy)->getResourceType();
- return DL.getTypeSizeInBits(Ty) / 8;
+ return cast<CBufferExtType>(HandleTy)->getCBufferSize();
}
dxil::SamplerType ResourceInfo::getSamplerType() const {
diff --git a/llvm/unittests/Analysis/DXILResourceTest.cpp b/llvm/unittests/Analysis/DXILResourceTest.cpp
index 2122f1a91cfc96..776c914b89a045 100644
--- a/llvm/unittests/Analysis/DXILResourceTest.cpp
+++ b/llvm/unittests/Analysis/DXILResourceTest.cpp
@@ -380,7 +380,8 @@ TEST(DXILResource, AnnotationsAndMetadata) {
// cbuffer cb0 { float4 g_X; float4 g_Y; }
StructType *CBufType0 =
StructType::create(Context, {Floatx4Ty, Floatx4Ty}, "cb0");
- HandleTy = llvm::TargetExtType::get(Context, "dx.CBuffer", CBufType0, {});
+ HandleTy =
+ llvm::TargetExtType::get(Context, "dx.CBuffer", CBufType0, {/*Size=*/32});
RI = ResourceInfo(
/*RecordID=*/0, /*Space=*/0, /*LowerBound=*/0, /*Size=*/1, HandleTy);
More information about the llvm-commits
mailing list