[flang-commits] [libcxx] [flang] [llvm] [lldb] [libcxxabi] [clang] [clang-tools-extra] [lld] [compiler-rt] [libc] [HLSL][DirectX] Move handling of resource element types into the frontend (PR #75674)

Justin Bogner via flang-commits flang-commits at lists.llvm.org
Mon Dec 18 09:52:01 PST 2023


https://github.com/bogner updated https://github.com/llvm/llvm-project/pull/75674

>From 9d6e00bd972a563daefd67b544614e2bb609cc42 Mon Sep 17 00:00:00 2001
From: Justin Bogner <mail at justinbogner.com>
Date: Fri, 15 Dec 2023 16:29:09 -0800
Subject: [PATCH] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20initia?=
 =?UTF-8?q?l=20version?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.5-bogner
---
 clang/lib/CodeGen/CGHLSLRuntime.cpp           | 63 +++++++++++--
 clang/lib/CodeGen/CGHLSLRuntime.h             |  2 +-
 .../builtins/RWBuffer-annotations.hlsl        | 14 +--
 .../builtins/RWBuffer-elementtype.hlsl        | 52 +++++++++++
 .../RasterizerOrderedBuffer-annotations.hlsl  | 12 +--
 clang/test/CodeGenHLSL/cbuf.hlsl              |  4 +-
 .../include/llvm/Frontend/HLSL/HLSLResource.h | 27 +++++-
 llvm/lib/Frontend/HLSL/HLSLResource.cpp       | 17 ++--
 llvm/lib/Target/DirectX/DXILResource.cpp      | 92 ++++++-------------
 llvm/lib/Target/DirectX/DXILResource.h        | 37 ++------
 llvm/test/CodeGen/DirectX/UAVMetadata.ll      | 22 ++---
 11 files changed, 204 insertions(+), 138 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/builtins/RWBuffer-elementtype.hlsl

diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 3e8a40e7540bef..e887d35198b3c7 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -182,10 +182,8 @@ void CGHLSLRuntime::finishCodeGen() {
     llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
                                       ? llvm::hlsl::ResourceKind::CBuffer
                                       : llvm::hlsl::ResourceKind::TBuffer;
-    std::string TyName =
-        Buf.Name.str() + (Buf.IsCBuffer ? ".cb." : ".tb.") + "ty";
-    addBufferResourceAnnotation(GV, TyName, RC, RK, /*IsROV=*/false,
-                                Buf.Binding);
+    addBufferResourceAnnotation(GV, RC, RK, /*IsROV=*/false,
+                                llvm::hlsl::ElementType::Invalid, Buf.Binding);
   }
 }
 
@@ -194,10 +192,10 @@ CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D)
       Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
 
 void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
-                                                llvm::StringRef TyName,
                                                 llvm::hlsl::ResourceClass RC,
                                                 llvm::hlsl::ResourceKind RK,
                                                 bool IsROV,
+                                                llvm::hlsl::ElementType ET,
                                                 BufferResBinding &Binding) {
   llvm::Module &M = CGM.getModule();
 
@@ -216,15 +214,68 @@ void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
     assert(false && "Unsupported buffer type!");
     return;
   }
 
   assert(ResourceMD != nullptr &&
          "ResourceMD must have been set by the switch above.");
 
   llvm::hlsl::FrontendResource Res(
-      GV, TyName, RK, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);
+      GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);
   ResourceMD->addOperand(Res.getMetadata());
 }
 
+static llvm::hlsl::ElementType
+calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy) {
+  using llvm::hlsl::ElementType;
+
+  // TODO: We may need to update this when we add things like ByteAddressBuffer
+  // that don't have a template parameter (or, indeed, an element type).
+  const auto *TST = ResourceTy->getAs<TemplateSpecializationType>();
+  assert(TST && "Resource types must be template specializations");
+  ArrayRef<TemplateArgument> Args = TST->template_arguments();
+  assert(!Args.empty() && "Resource has no element type");
+
+  // At this point we have a resource with an element type, so we can assume
+  // that it's valid or we would have diagnosed the error earlier.
+  QualType ElTy = Args[0].getAsType();
+
+  // We should either have a basic type or a vector of a basic type.
+  if (const auto *VecTy = ElTy->getAs<clang::VectorType>())
+    ElTy = VecTy->getElementType();
+
+  // clang classifies bool as an unsigned integer type, so check for it before
+  // the integer cases below; otherwise it would never map to I1.
+  if (ElTy->isBooleanType())
+    return ElementType::I1;
+
+  if (ElTy->isSignedIntegerType()) {
+    switch (Context.getTypeSize(ElTy)) {
+    case 16:
+      return ElementType::I16;
+    case 32:
+      return ElementType::I32;
+    case 64:
+      return ElementType::I64;
+    }
+  } else if (ElTy->isUnsignedIntegerType()) {
+    switch (Context.getTypeSize(ElTy)) {
+    case 16:
+      return ElementType::U16;
+    case 32:
+      return ElementType::U32;
+    case 64:
+      return ElementType::U64;
+    }
+  } else if (ElTy->isSpecificBuiltinType(BuiltinType::Half))
+    return ElementType::F16;
+  else if (ElTy->isSpecificBuiltinType(BuiltinType::Float))
+    return ElementType::F32;
+  else if (ElTy->isSpecificBuiltinType(BuiltinType::Double))
+    return ElementType::F64;
+
+  // TODO: We need to handle unorm/snorm float types here once we support them
+  llvm_unreachable("Invalid element type for resource");
+}
+
 void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
   const Type *Ty = D->getType()->getPointeeOrArrayElementType();
   if (!Ty)
@@ -239,10 +290,10 @@ void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
   llvm::hlsl::ResourceClass RC = Attr->getResourceClass();
   llvm::hlsl::ResourceKind RK = Attr->getResourceKind();
   bool IsROV = Attr->getIsROV();
+  llvm::hlsl::ElementType ET = calculateElementType(CGM.getContext(), Ty);
 
-  QualType QT(Ty, 0);
   BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());
-  addBufferResourceAnnotation(GV, QT.getAsString(), RC, RK, IsROV, Binding);
+  addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding);
 }
 
 CGHLSLRuntime::BufferResBinding::BufferResBinding(
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index bb500cb5c979f2..bffefb66740a00 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -90,9 +90,9 @@ class CGHLSLRuntime {
 
 private:
   void addBufferResourceAnnotation(llvm::GlobalVariable *GV,
-                                   llvm::StringRef TyName,
                                    llvm::hlsl::ResourceClass RC,
                                    llvm::hlsl::ResourceKind RK, bool IsROV,
+                                   llvm::hlsl::ElementType ET,
                                    BufferResBinding &Binding);
   void addConstant(VarDecl *D, Buffer &CB);
   void addBufferDecls(const DeclContext *DC, Buffer &CB);
diff --git a/clang/test/CodeGenHLSL/builtins/RWBuffer-annotations.hlsl b/clang/test/CodeGenHLSL/builtins/RWBuffer-annotations.hlsl
index a70e224b81e4b7..7ca78e60fb9c59 100644
--- a/clang/test/CodeGenHLSL/builtins/RWBuffer-annotations.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/RWBuffer-annotations.hlsl
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s 
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
 
 RWBuffer<float> Buffer1;
 RWBuffer<vector<float, 4> > BufferArray[4];
@@ -16,9 +16,9 @@ void main() {
 }
 
 // CHECK: !hlsl.uavs = !{![[Single:[0-9]+]], ![[Array:[0-9]+]], ![[SingleAllocated:[0-9]+]], ![[ArrayAllocated:[0-9]+]], ![[SingleSpace:[0-9]+]], ![[ArraySpace:[0-9]+]]}
-// CHECK-DAG: ![[Single]] = !{ptr @"?Buffer1@@3V?$RWBuffer at M@hlsl@@A", !"RWBuffer<float>", i32 10, i1 false, i32 -1, i32 0}
-// CHECK-DAG: ![[Array]] = !{ptr @"?BufferArray@@3PAV?$RWBuffer at T?$__vector at M$03 at __clang@@@hlsl@@A", !"RWBuffer<vector<float, 4> >", i32 10, i1 false, i32 -1, i32 0}
-// CHECK-DAG: ![[SingleAllocated]] = !{ptr @"?Buffer2@@3V?$RWBuffer at M@hlsl@@A", !"RWBuffer<float>", i32 10, i1 false, i32 3, i32 0}
-// CHECK-DAG: ![[ArrayAllocated]] = !{ptr @"?BufferArray2@@3PAV?$RWBuffer at T?$__vector at M$03 at __clang@@@hlsl@@A", !"RWBuffer<vector<float, 4> >", i32 10, i1 false, i32 4, i32 0}
-// CHECK-DAG: ![[SingleSpace]] = !{ptr @"?Buffer3@@3V?$RWBuffer at M@hlsl@@A", !"RWBuffer<float>", i32 10, i1 false, i32 3, i32 1}
-// CHECK-DAG: ![[ArraySpace]] = !{ptr @"?BufferArray3@@3PAV?$RWBuffer at T?$__vector at M$03 at __clang@@@hlsl@@A", !"RWBuffer<vector<float, 4> >", i32 10, i1 false, i32 4, i32 1}
+// CHECK-DAG: ![[Single]] = !{ptr @"?Buffer1@@3V?$RWBuffer at M@hlsl@@A", i32 10, i32 9, i1 false, i32 -1, i32 0}
+// CHECK-DAG: ![[Array]] = !{ptr @"?BufferArray@@3PAV?$RWBuffer at T?$__vector at M$03 at __clang@@@hlsl@@A", i32 10, i32 9, i1 false, i32 -1, i32 0}
+// CHECK-DAG: ![[SingleAllocated]] = !{ptr @"?Buffer2@@3V?$RWBuffer at M@hlsl@@A", i32 10, i32 9, i1 false, i32 3, i32 0}
+// CHECK-DAG: ![[ArrayAllocated]] = !{ptr @"?BufferArray2@@3PAV?$RWBuffer at T?$__vector at M$03 at __clang@@@hlsl@@A", i32 10, i32 9, i1 false, i32 4, i32 0}
+// CHECK-DAG: ![[SingleSpace]] = !{ptr @"?Buffer3@@3V?$RWBuffer at M@hlsl@@A", i32 10, i32 9, i1 false, i32 3, i32 1}
+// CHECK-DAG: ![[ArraySpace]] = !{ptr @"?BufferArray3@@3PAV?$RWBuffer at T?$__vector at M$03 at __clang@@@hlsl@@A", i32 10, i32 9, i1 false, i32 4, i32 1}
diff --git a/clang/test/CodeGenHLSL/builtins/RWBuffer-elementtype.hlsl b/clang/test/CodeGenHLSL/builtins/RWBuffer-elementtype.hlsl
new file mode 100644
index 00000000000000..87002ccd462d3f
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/RWBuffer-elementtype.hlsl
@@ -0,0 +1,55 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -finclude-default-header -fnative-half-type -emit-llvm -o - %s | FileCheck %s
+
+RWBuffer<bool> BufBool;
+RWBuffer<int16_t> BufI16;
+RWBuffer<uint16_t> BufU16;
+RWBuffer<int> BufI32;
+RWBuffer<uint> BufU32;
+RWBuffer<int64_t> BufI64;
+RWBuffer<uint64_t> BufU64;
+RWBuffer<half> BufF16;
+RWBuffer<float> BufF32;
+RWBuffer<double> BufF64;
+RWBuffer< vector<int16_t, 4> > BufI16x4;
+RWBuffer< vector<uint, 3> > BufU32x3;
+RWBuffer<half2> BufF16x2;
+RWBuffer<float3> BufF32x3;
+// TODO: RWBuffer<snorm half> BufSNormF16; -> 11
+// TODO: RWBuffer<unorm half> BufUNormF16; -> 12
+// TODO: RWBuffer<snorm float> BufSNormF32; -> 13
+// TODO: RWBuffer<unorm float> BufUNormF32; -> 14
+// TODO: RWBuffer<snorm double> BufSNormF64; -> 15
+// TODO: RWBuffer<unorm double> BufUNormF64; -> 16
+
+[numthreads(1,1,1)]
+void main(int GI : SV_GroupIndex) {
+  BufBool[GI] = false;
+  BufI16[GI] = 0;
+  BufU16[GI] = 0;
+  BufI32[GI] = 0;
+  BufU32[GI] = 0;
+  BufI64[GI] = 0;
+  BufU64[GI] = 0;
+  BufF16[GI] = 0;
+  BufF32[GI] = 0;
+  BufF64[GI] = 0;
+  BufI16x4[GI] = 0;
+  BufU32x3[GI] = 0;
+  BufF16x2[GI] = 0;
+  BufF32x3[GI] = 0;
+}
+
+// CHECK: !{{[0-9]+}} = !{ptr @"?BufBool{{.*}}", i32 10, i32 1,
+// CHECK: !{{[0-9]+}} = !{ptr @"?BufI16@@3V?$RWBuffer at F@hlsl@@A", i32 10, i32 2,
+// CHECK: !{{[0-9]+}} = !{ptr @"?BufU16@@3V?$RWBuffer at G@hlsl@@A", i32 10, i32 3,
+// CHECK: !{{[0-9]+}} = !{ptr @"?BufI32@@3V?$RWBuffer at H@hlsl@@A", i32 10, i32 4,
+// CHECK: !{{[0-9]+}} = !{ptr @"?BufU32@@3V?$RWBuffer at I@hlsl@@A", i32 10, i32 5,
+// CHECK: !{{[0-9]+}} = !{ptr @"?BufI64@@3V?$RWBuffer at J@hlsl@@A", i32 10, i32 6,
+// CHECK: !{{[0-9]+}} = !{ptr @"?BufU64@@3V?$RWBuffer at K@hlsl@@A", i32 10, i32 7,
+// CHECK: !{{[0-9]+}} = !{ptr @"?BufF16@@3V?$RWBuffer@$f16@@hlsl@@A", i32 10, i32 8,
+// CHECK: !{{[0-9]+}} = !{ptr @"?BufF32@@3V?$RWBuffer at M@hlsl@@A", i32 10, i32 9,
+// CHECK: !{{[0-9]+}} = !{ptr @"?BufF64@@3V?$RWBuffer at N@hlsl@@A", i32 10, i32 10,
+// CHECK: !{{[0-9]+}} = !{ptr @"?BufI16x4@@3V?$RWBuffer at T?$__vector at F$03 at __clang@@@hlsl@@A", i32 10, i32 2,
+// CHECK: !{{[0-9]+}} = !{ptr @"?BufU32x3@@3V?$RWBuffer at T?$__vector at I$02 at __clang@@@hlsl@@A", i32 10, i32 5,
+// CHECK: !{{[0-9]+}} = !{ptr @"?BufF16x2@@3V?$RWBuffer at T?$__vector@$f16@$01 at __clang@@@hlsl@@A", i32 10, i32 8,
+// CHECK: !{{[0-9]+}} = !{ptr @"?BufF32x3@@3V?$RWBuffer at T?$__vector at M$02 at __clang@@@hlsl@@A", i32 10, i32 9,
diff --git a/clang/test/CodeGenHLSL/builtins/RasterizerOrderedBuffer-annotations.hlsl b/clang/test/CodeGenHLSL/builtins/RasterizerOrderedBuffer-annotations.hlsl
index ce7d84ecf5b147..bf70cc2456c8bc 100644
--- a/clang/test/CodeGenHLSL/builtins/RasterizerOrderedBuffer-annotations.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/RasterizerOrderedBuffer-annotations.hlsl
@@ -12,9 +12,9 @@ RasterizerOrderedBuffer<vector<float, 4> > BufferArray3[4] : register(u4, space1
 void main() {}
 
 // CHECK: !hlsl.uavs = !{![[Single:[0-9]+]], ![[Array:[0-9]+]], ![[SingleAllocated:[0-9]+]], ![[ArrayAllocated:[0-9]+]], ![[SingleSpace:[0-9]+]], ![[ArraySpace:[0-9]+]]}
-// CHECK-DAG: ![[Single]] = !{ptr @"?Buffer1@@3V?$RasterizerOrderedBuffer at M@hlsl@@A", !"RasterizerOrderedBuffer<float>", i32 10, i1 true, i32 -1, i32 0}
-// CHECK-DAG: ![[Array]] = !{ptr @"?BufferArray@@3PAV?$RasterizerOrderedBuffer at T?$__vector at M$03 at __clang@@@hlsl@@A", !"RasterizerOrderedBuffer<vector<float, 4> >", i32 10, i1 true, i32 -1, i32 0}
-// CHECK-DAG: ![[SingleAllocated]] = !{ptr @"?Buffer2@@3V?$RasterizerOrderedBuffer at M@hlsl@@A", !"RasterizerOrderedBuffer<float>", i32 10, i1 true, i32 3, i32 0}
-// CHECK-DAG: ![[ArrayAllocated]] = !{ptr @"?BufferArray2@@3PAV?$RasterizerOrderedBuffer at T?$__vector at M$03 at __clang@@@hlsl@@A", !"RasterizerOrderedBuffer<vector<float, 4> >", i32 10, i1 true, i32 4, i32 0}
-// CHECK-DAG: ![[SingleSpace]] = !{ptr @"?Buffer3@@3V?$RasterizerOrderedBuffer at M@hlsl@@A", !"RasterizerOrderedBuffer<float>", i32 10, i1 true, i32 3, i32 1}
-// CHECK-DAG: ![[ArraySpace]] = !{ptr @"?BufferArray3@@3PAV?$RasterizerOrderedBuffer at T?$__vector at M$03 at __clang@@@hlsl@@A", !"RasterizerOrderedBuffer<vector<float, 4> >", i32 10, i1 true, i32 4, i32 1}
+// CHECK-DAG: ![[Single]] = !{ptr @"?Buffer1@@3V?$RasterizerOrderedBuffer at M@hlsl@@A", i32 10, i32 9, i1 true, i32 -1, i32 0}
+// CHECK-DAG: ![[Array]] = !{ptr @"?BufferArray@@3PAV?$RasterizerOrderedBuffer at T?$__vector at M$03 at __clang@@@hlsl@@A", i32 10, i32 9, i1 true, i32 -1, i32 0}
+// CHECK-DAG: ![[SingleAllocated]] = !{ptr @"?Buffer2@@3V?$RasterizerOrderedBuffer at M@hlsl@@A", i32 10, i32 9, i1 true, i32 3, i32 0}
+// CHECK-DAG: ![[ArrayAllocated]] = !{ptr @"?BufferArray2@@3PAV?$RasterizerOrderedBuffer at T?$__vector at M$03 at __clang@@@hlsl@@A", i32 10, i32 9, i1 true, i32 4, i32 0}
+// CHECK-DAG: ![[SingleSpace]] = !{ptr @"?Buffer3@@3V?$RasterizerOrderedBuffer at M@hlsl@@A", i32 10, i32 9, i1 true, i32 3, i32 1}
+// CHECK-DAG: ![[ArraySpace]] = !{ptr @"?BufferArray3@@3PAV?$RasterizerOrderedBuffer at T?$__vector at M$03 at __clang@@@hlsl@@A", i32 10, i32 9, i1 true, i32 4, i32 1}
diff --git a/clang/test/CodeGenHLSL/cbuf.hlsl b/clang/test/CodeGenHLSL/cbuf.hlsl
index 5dee1feb902aa0..dc2a6aaa8f4335 100644
--- a/clang/test/CodeGenHLSL/cbuf.hlsl
+++ b/clang/test/CodeGenHLSL/cbuf.hlsl
@@ -24,5 +24,5 @@ float foo() {
 
 // CHECK: !hlsl.cbufs = !{![[CBMD:[0-9]+]]}
 // CHECK: !hlsl.srvs = !{![[TBMD:[0-9]+]]}
-// CHECK: ![[CBMD]] = !{ptr @[[CB]], !"A.cb.ty", i32 13, i1 false, i32 0, i32 2}
-// CHECK: ![[TBMD]] = !{ptr @[[TB]], !"A.tb.ty", i32 15, i1 false, i32 2, i32 1}
+// CHECK: ![[CBMD]] = !{ptr @[[CB]], i32 13, i32 0, i1 false, i32 0, i32 2}
+// CHECK: ![[TBMD]] = !{ptr @[[TB]], i32 15, i32 0, i1 false, i32 2, i32 1}
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLResource.h b/llvm/include/llvm/Frontend/HLSL/HLSLResource.h
index eedecaea4e58da..068b4c66711e46 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLResource.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLResource.h
@@ -54,6 +54,30 @@ enum class ResourceKind : uint32_t {
   NumEntries,
 };
 
+// The value ordering of this enumeration is part of the DXIL ABI. Elements
+// can only be added to the end, and not removed.
+enum class ElementType : uint32_t {
+  Invalid = 0,
+  I1,
+  I16,
+  U16,
+  I32,
+  U32,
+  I64,
+  U64,
+  F16,
+  F32,
+  F64,
+  SNormF16,
+  UNormF16,
+  SNormF32,
+  UNormF32,
+  SNormF64,
+  UNormF64,
+  PackedS8x32,
+  PackedU8x32,
+};
+
 class FrontendResource {
   MDNode *Entry;
 
@@ -62,12 +86,12 @@ class FrontendResource {
     assert(Entry->getNumOperands() == 6 && "Unexpected metadata shape");
   }
 
-  FrontendResource(GlobalVariable *GV, StringRef TypeStr, ResourceKind RK,
+  FrontendResource(GlobalVariable *GV, ResourceKind RK, ElementType ElTy,
                    bool IsROV, uint32_t ResIndex, uint32_t Space);
 
   GlobalVariable *getGlobalVariable();
-  StringRef getSourceType();
   ResourceKind getResourceKind();
+  ElementType getElementType();
   bool getIsROV();
   uint32_t getResourceIndex();
   uint32_t getSpace();
diff --git a/llvm/lib/Frontend/HLSL/HLSLResource.cpp b/llvm/lib/Frontend/HLSL/HLSLResource.cpp
index 709fe3212623ef..bcdbe5eadc69e5 100644
--- a/llvm/lib/Frontend/HLSL/HLSLResource.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLResource.cpp
@@ -23,12 +23,15 @@ GlobalVariable *FrontendResource::getGlobalVariable() {
       cast<ConstantAsMetadata>(Entry->getOperand(0))->getValue());
 }
 
-StringRef FrontendResource::getSourceType() {
-  return cast<MDString>(Entry->getOperand(1))->getString();
-}
-
 ResourceKind FrontendResource::getResourceKind() {
   return static_cast<ResourceKind>(
+      cast<ConstantInt>(
+          cast<ConstantAsMetadata>(Entry->getOperand(1))->getValue())
+          ->getLimitedValue());
+}
+
+ElementType FrontendResource::getElementType() {
+  return static_cast<ElementType>(
       cast<ConstantInt>(
           cast<ConstantAsMetadata>(Entry->getOperand(2))->getValue())
           ->getLimitedValue());
@@ -49,14 +52,15 @@ uint32_t FrontendResource::getSpace() {
       ->getLimitedValue();
 }
 
-FrontendResource::FrontendResource(GlobalVariable *GV, StringRef TypeStr,
-                                   ResourceKind RK, bool IsROV,
+FrontendResource::FrontendResource(GlobalVariable *GV, ResourceKind RK,
+                                   ElementType ElTy, bool IsROV,
                                    uint32_t ResIndex, uint32_t Space) {
   auto &Ctx = GV->getContext();
   IRBuilder<> B(Ctx);
   Entry = MDNode::get(
-      Ctx, {ValueAsMetadata::get(GV), MDString::get(Ctx, TypeStr),
+      Ctx, {ValueAsMetadata::get(GV),
             ConstantAsMetadata::get(B.getInt32(static_cast<int>(RK))),
+            ConstantAsMetadata::get(B.getInt32(static_cast<int>(ElTy))),
             ConstantAsMetadata::get(B.getInt1(IsROV)),
             ConstantAsMetadata::get(B.getInt32(ResIndex)),
             ConstantAsMetadata::get(B.getInt32(Space))});
diff --git a/llvm/lib/Target/DirectX/DXILResource.cpp b/llvm/lib/Target/DirectX/DXILResource.cpp
index 92306d907e0546..b22f6d3ca4cd55 100644
--- a/llvm/lib/Target/DirectX/DXILResource.cpp
+++ b/llvm/lib/Target/DirectX/DXILResource.cpp
@@ -63,57 +63,56 @@ ResourceBase::ResourceBase(uint32_t I, FrontendResource R)
     RangeSize = ArrTy->getNumElements();
 }
 
-StringRef ResourceBase::getComponentTypeName(ComponentType CompType) {
-  switch (CompType) {
-  case ComponentType::LastEntry:
-  case ComponentType::Invalid:
+StringRef ResourceBase::getElementTypeName(ElementType ElTy) {
+  switch (ElTy) {
+  case ElementType::Invalid:
     return "invalid";
-  case ComponentType::I1:
+  case ElementType::I1:
     return "i1";
-  case ComponentType::I16:
+  case ElementType::I16:
     return "i16";
-  case ComponentType::U16:
+  case ElementType::U16:
     return "u16";
-  case ComponentType::I32:
+  case ElementType::I32:
     return "i32";
-  case ComponentType::U32:
+  case ElementType::U32:
     return "u32";
-  case ComponentType::I64:
+  case ElementType::I64:
     return "i64";
-  case ComponentType::U64:
+  case ElementType::U64:
     return "u64";
-  case ComponentType::F16:
+  case ElementType::F16:
     return "f16";
-  case ComponentType::F32:
+  case ElementType::F32:
     return "f32";
-  case ComponentType::F64:
+  case ElementType::F64:
     return "f64";
-  case ComponentType::SNormF16:
+  case ElementType::SNormF16:
     return "snorm_f16";
-  case ComponentType::UNormF16:
+  case ElementType::UNormF16:
     return "unorm_f16";
-  case ComponentType::SNormF32:
+  case ElementType::SNormF32:
     return "snorm_f32";
-  case ComponentType::UNormF32:
+  case ElementType::UNormF32:
     return "unorm_f32";
-  case ComponentType::SNormF64:
+  case ElementType::SNormF64:
     return "snorm_f64";
-  case ComponentType::UNormF64:
+  case ElementType::UNormF64:
     return "unorm_f64";
-  case ComponentType::PackedS8x32:
+  case ElementType::PackedS8x32:
     return "p32i8";
-  case ComponentType::PackedU8x32:
+  case ElementType::PackedU8x32:
     return "p32u8";
   }
-  llvm_unreachable("All ComponentType enums are handled in switch");
+  llvm_unreachable("All ElementType enums are handled in switch");
 }
 
-void ResourceBase::printComponentType(Kinds Kind, ComponentType CompType,
-                                      unsigned Alignment, raw_ostream &OS) {
+void ResourceBase::printElementType(Kinds Kind, ElementType ElTy,
+                                    unsigned Alignment, raw_ostream &OS) {
   switch (Kind) {
   default:
     // TODO: add vector size.
-    OS << right_justify(getComponentTypeName(CompType), Alignment);
+    OS << right_justify(getElementTypeName(ElTy), Alignment);
     break;
   case Kinds::RawBuffer:
     OS << right_justify("byte", Alignment);
@@ -232,19 +231,13 @@ void ResourceBase::print(raw_ostream &OS, StringRef IDPrefix,
     OS << right_justify("unbounded", 6) << "\n";
   }
 
-UAVResource::UAVResource(uint32_t I, FrontendResource R)
-    : ResourceBase(I, R), Shape(R.getResourceKind()), GloballyCoherent(false),
-      HasCounter(false), IsROV(R.getIsROV()), ExtProps() {
-  parseSourceType(R.getSourceType());
-}
-
 void UAVResource::print(raw_ostream &OS) const {
   OS << "; " << left_justify(Name, 31);
 
   OS << right_justify("UAV", 10);
 
-  printComponentType(
-      Shape, ExtProps.ElementType.value_or(ComponentType::Invalid), 8, OS);
+  printElementType(
+      Shape, ExtProps.ElementType.value_or(ElementType::Invalid), 8, OS);
 
   // FIXME: support SampleCount.
   // See https://github.com/llvm/llvm-project/issues/58175
@@ -253,35 +246,6 @@ void UAVResource::print(raw_ostream &OS) const {
   ResourceBase::print(OS, "U", "u");
 }
 
-// FIXME: Capture this in HLSL source. I would go do this right now, but I want
-// to get this in first so that I can make sure to capture all the extra
-// information we need to remove the source type string from here (See issue:
-// https://github.com/llvm/llvm-project/issues/57991).
-void UAVResource::parseSourceType(StringRef S) {
-  S = S.substr(S.find("<") + 1);
-
-  constexpr size_t PrefixLen = StringRef("vector<").size();
-  if (S.startswith("vector<"))
-    S = S.substr(PrefixLen, S.find(",") - PrefixLen);
-  else
-    S = S.substr(0, S.find(">"));
-
-  ComponentType ElTy = StringSwitch<ResourceBase::ComponentType>(S)
-                           .Case("bool", ComponentType::I1)
-                           .Case("int16_t", ComponentType::I16)
-                           .Case("uint16_t", ComponentType::U16)
-                           .Case("int32_t", ComponentType::I32)
-                           .Case("uint32_t", ComponentType::U32)
-                           .Case("int64_t", ComponentType::I64)
-                           .Case("uint64_t", ComponentType::U64)
-                           .Case("half", ComponentType::F16)
-                           .Case("float", ComponentType::F32)
-                           .Case("double", ComponentType::F64)
-                           .Default(ComponentType::Invalid);
-  if (ElTy != ComponentType::Invalid)
-    ExtProps.ElementType = ElTy;
-}
-
 ConstantBuffer::ConstantBuffer(uint32_t I, hlsl::FrontendResource R)
     : ResourceBase(I, R) {}
 
@@ -294,7 +258,7 @@ void ConstantBuffer::print(raw_ostream &OS) const {
 
   OS << right_justify("cbuffer", 10);
 
-  printComponentType(Kinds::CBuffer, ComponentType::Invalid, 8, OS);
+  printElementType(Kinds::CBuffer, ElementType::Invalid, 8, OS);
 
   printKind(Kinds::CBuffer, 12, OS, /*SRV*/ false, /*HasCounter*/ false);
   // Print the binding part.
diff --git a/llvm/lib/Target/DirectX/DXILResource.h b/llvm/lib/Target/DirectX/DXILResource.h
index cb39020bc61eb9..5f8b0badd145c1 100644
--- a/llvm/lib/Target/DirectX/DXILResource.h
+++ b/llvm/lib/Target/DirectX/DXILResource.h
@@ -46,38 +46,13 @@ class ResourceBase {
                         bool SRV = false, bool HasCounter = false,
                         uint32_t SampleCount = 0);
 
-  // The value ordering of this enumeration is part of the DXIL ABI. Elements
-  // can only be added to the end, and not removed.
-  enum class ComponentType : uint32_t {
-    Invalid = 0,
-    I1,
-    I16,
-    U16,
-    I32,
-    U32,
-    I64,
-    U64,
-    F16,
-    F32,
-    F64,
-    SNormF16,
-    UNormF16,
-    SNormF32,
-    UNormF32,
-    SNormF64,
-    UNormF64,
-    PackedS8x32,
-    PackedU8x32,
-    LastEntry
-  };
-
-  static StringRef getComponentTypeName(ComponentType CompType);
-  static void printComponentType(Kinds Kind, ComponentType CompType,
-                                 unsigned Alignment, raw_ostream &OS);
+  static StringRef getElementTypeName(hlsl::ElementType ElTy);
+  static void printElementType(Kinds Kind, hlsl::ElementType ElTy,
+                               unsigned Alignment, raw_ostream &OS);
 
 public:
   struct ExtendedProperties {
-    std::optional<ComponentType> ElementType;
+    std::optional<hlsl::ElementType> ElementType;
 
     // The value ordering of this enumeration is part of the DXIL ABI. Elements
     // can only be added to the end, and not removed.
@@ -102,7 +77,15 @@ class UAVResource : public ResourceBase {
-  void parseSourceType(StringRef S);
 
 public:
-  UAVResource(uint32_t I, hlsl::FrontendResource R);
+  UAVResource(uint32_t I, hlsl::FrontendResource R)
+      : ResourceBase(I, R), Shape(R.getResourceKind()), GloballyCoherent(false),
+        HasCounter(false), IsROV(R.getIsROV()), ExtProps() {
+    // Only record an element type the frontend actually provided; Invalid
+    // leaves the optional empty, matching the old source-string parser, which
+    // only set it for recognized element types.
+    hlsl::ElementType ElTy = R.getElementType();
+    if (ElTy != hlsl::ElementType::Invalid)
+      ExtProps.ElementType = ElTy;
+  }
 
   MDNode *write() const;
   void print(raw_ostream &O) const;
diff --git a/llvm/test/CodeGen/DirectX/UAVMetadata.ll b/llvm/test/CodeGen/DirectX/UAVMetadata.ll
index 3d95723d6e49f0..0bc8a8cfcd713b 100644
--- a/llvm/test/CodeGen/DirectX/UAVMetadata.ll
+++ b/llvm/test/CodeGen/DirectX/UAVMetadata.ll
@@ -37,22 +37,22 @@ target triple = "dxil-pc-shadermodel6.0-library"
 
 !hlsl.uavs = !{!0, !1, !2, !3, !4, !5, !6, !7, !8, !9}
 
-!0 = !{ptr @Zero, !"RWBuffer<half>", i32 10, i1 false, i32 0, i32 0}
-!1 = !{ptr @One, !"Buffer<vector<float,4>>", i32 10, i1 false, i32 1, i32 0}
-!2 = !{ptr @Two, !"Buffer<double>", i32 10, i1 false, i32 2, i32 0}
-!3 = !{ptr @Three, !"Buffer<bool>", i32 10, i1 false, i32 3, i32 0}
-!4 = !{ptr @Four, !"ByteAddressBuffer<int16_t>", i32 11, i1 false, i32 5, i32 0}
-!5 = !{ptr @Five, !"StructuredBuffer<uint16_t>", i32 12, i1 false, i32 6, i32 0}
-!6 = !{ptr @Six, !"RasterizerOrderedBuffer<int32_t>", i32 10, i1 true, i32 7, i32 0}
-!7 = !{ptr @Seven, !"RasterizerOrderedStructuredBuffer<uint32_t>", i32 12, i1 true, i32 8, i32 0}
-!8 = !{ptr @Eight, !"RasterizerOrderedByteAddressBuffer<int64_t>", i32 11, i1 true, i32 9, i32 0}
-!9 = !{ptr @Nine, !"RWBuffer<uint64_t>", i32 10, i1 false, i32 10, i32 2}
+!0 = !{ptr @Zero, i32 10, i32 8, i1 false, i32 0, i32 0}
+!1 = !{ptr @One, i32 10, i32 9, i1 false, i32 1, i32 0}
+!2 = !{ptr @Two, i32 10, i32 10, i1 false, i32 2, i32 0}
+!3 = !{ptr @Three, i32 10, i32 1, i1 false, i32 3, i32 0}
+!4 = !{ptr @Four, i32 11, i32 2, i1 false, i32 5, i32 0}
+!5 = !{ptr @Five, i32 12, i32 3, i1 false, i32 6, i32 0}
+!6 = !{ptr @Six, i32 10, i32 4, i1 true, i32 7, i32 0}
+!7 = !{ptr @Seven, i32 12, i32 5, i1 true, i32 8, i32 0}
+!8 = !{ptr @Eight, i32 11, i32 6, i1 true, i32 9, i32 0}
+!9 = !{ptr @Nine, i32 10, i32 7, i1 false, i32 10, i32 2}
 
 ; CHECK: !dx.resources = !{[[ResList:[!][0-9]+]]}
 
 ; CHECK: [[ResList]] = !{null, [[UAVList:[!][0-9]+]], null, null}
 ; CHECK: [[UAVList]] = !{[[Zero:[!][0-9]+]], [[One:[!][0-9]+]],
-; CHECK-SAME: [[Two:[!][0-9]+]], [[Three:[!][0-9]+]], [[Four:[!][0-9]+]], 
+; CHECK-SAME: [[Two:[!][0-9]+]], [[Three:[!][0-9]+]], [[Four:[!][0-9]+]],
 ; CHECK-SAME: [[Five:[!][0-9]+]], [[Six:[!][0-9]+]], [[Seven:[!][0-9]+]],
 ; CHECK-SAME: [[Eight:[!][0-9]+]], [[Nine:[!][0-9]+]]}
 ; CHECK: [[Zero]] = !{i32 0, ptr @Zero, !"", i32 0, i32 0, i32 1, i32 10, i1 false, i1 false, i1 false, [[Half:[!][0-9]+]]}



More information about the flang-commits mailing list