[clang] [llvm] [SPIR-V] DRAFT: ext_builtin_input/ext_builtin_output (PR #115187)
Nathan Gauër via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 8 06:34:58 PST 2024
https://github.com/Keenuts updated https://github.com/llvm/llvm-project/pull/115187
>From 357f8e613e030967f6a95ccbeffe03c0f5f8c186 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Wed, 6 Nov 2024 17:48:13 +0100
Subject: [PATCH] [SPIR-V] DRAFT: ext_builtin_input/ext_builtin_output
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
do not submit
Signed-off-by: Nathan Gauër <brioche at google.com>
---
clang/include/clang/Basic/AddressSpaces.h | 4 ++
clang/include/clang/Basic/Attr.td | 33 ++++++++++
clang/include/clang/Basic/AttrDocs.td | 31 +++++++++
clang/include/clang/Sema/ParsedAttr.h | 11 ++++
clang/include/clang/Sema/SemaHLSL.h | 2 +
clang/lib/AST/TypePrinter.cpp | 4 ++
clang/lib/Basic/TargetInfo.cpp | 2 +
clang/lib/Basic/Targets/AArch64.h | 2 +
clang/lib/Basic/Targets/AMDGPU.cpp | 4 ++
clang/lib/Basic/Targets/DirectX.h | 2 +
clang/lib/Basic/Targets/NVPTX.h | 2 +
clang/lib/Basic/Targets/SPIR.h | 4 ++
clang/lib/Basic/Targets/SystemZ.h | 2 +
clang/lib/Basic/Targets/TCE.h | 2 +
clang/lib/Basic/Targets/WebAssembly.h | 42 ++++++------
clang/lib/Basic/Targets/X86.h | 2 +
clang/lib/CodeGen/CGHLSLRuntime.cpp | 56 ++++++++++++++++
clang/lib/CodeGen/CGHLSLRuntime.h | 4 ++
clang/lib/CodeGen/CodeGenModule.cpp | 7 +-
clang/lib/Parse/ParseHLSL.cpp | 1 +
clang/lib/Sema/SemaDeclAttr.cpp | 9 +++
clang/lib/Sema/SemaHLSL.cpp | 64 +++++++++++++++++++
.../SemaTemplate/address_space-dependent.cpp | 4 +-
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 2 +-
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 4 +-
llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp | 4 +-
llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 2 +
llvm/lib/Target/SPIRV/SPIRVUtils.h | 2 +
28 files changed, 280 insertions(+), 28 deletions(-)
diff --git a/clang/include/clang/Basic/AddressSpaces.h b/clang/include/clang/Basic/AddressSpaces.h
index 7b723d508fff17..b829d9d3be8688 100644
--- a/clang/include/clang/Basic/AddressSpaces.h
+++ b/clang/include/clang/Basic/AddressSpaces.h
@@ -59,6 +59,10 @@ enum class LangAS : unsigned {
// HLSL specific address spaces.
hlsl_groupshared,
+ // Vulkan specific address spaces.
+ vulkan_input,
+ vulkan_output,
+
// Wasm specific address spaces.
wasm_funcref,
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 24cfb5ddb6d4ca..4eb3962dcc583b 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -164,6 +164,16 @@ def HLSLBufferObj : SubsetSubject<HLSLBuffer,
[{isa<HLSLBufferDecl>(S)}],
"cbuffer/tbuffer">;
+def HLSLInputBuiltin : SubsetSubject<Var,
+ [{S->hasGlobalStorage() && S->getType().isConstQualified() &&
+ S->getStorageClass()==StorageClass::SC_Extern}],
+ "static const global variables">;
+
+def HLSLOutputBuiltin : SubsetSubject<Var,
+ [{S->hasGlobalStorage() && !S->getType().isConstQualified() &&
+ S->getStorageClass()==StorageClass::SC_Extern}], // NOTE(review): predicate requires extern, description says static - confirm.
+ "static non-const global variables">;
+
def ClassTmpl : SubsetSubject<CXXRecord, [{S->getDescribedClassTemplate()}],
"class templates">;
@@ -4588,6 +4598,22 @@ def HLSLNumThreads: InheritableAttr {
let Documentation = [NumThreadsDocs];
}
+def HLSLVkExtBuiltinInput: InheritableAttr {
+ let Spellings = [CXX11<"vk", "ext_builtin_input">];
+ let Args = [IntArgument<"BuiltIn">];
+ let Subjects = SubjectList<[HLSLInputBuiltin], ErrorDiag>;
+ let LangOpts = [HLSL];
+ let Documentation = [VkExtBuiltinInputDocs];
+}
+
+def HLSLVkExtBuiltinOutput: InheritableAttr {
+ let Spellings = [CXX11<"vk", "ext_builtin_output">];
+ let Args = [IntArgument<"BuiltIn">];
+ let Subjects = SubjectList<[HLSLOutputBuiltin], ErrorDiag>;
+ let LangOpts = [HLSL];
+ let Documentation = [VkExtBuiltinOutputDocs];
+}
+
def HLSLSV_GroupIndex: HLSLAnnotationAttr {
let Spellings = [HLSLAnnotation<"SV_GroupIndex">];
let Subjects = SubjectList<[ParmVar, GlobalVar]>;
@@ -4595,6 +4621,13 @@ def HLSLSV_GroupIndex: HLSLAnnotationAttr {
let Documentation = [HLSLSV_GroupIndexDocs];
}
+def HLSLSV_GroupID: HLSLAnnotationAttr {
+ let Spellings = [HLSLAnnotation<"SV_GroupID">];
+ let Subjects = SubjectList<[ParmVar, GlobalVar]>;
+ let LangOpts = [HLSL];
+ let Documentation = [HLSLSV_GroupIDDocs];
+}
+
def HLSLResourceBinding: InheritableAttr {
let Spellings = [HLSLAnnotation<"register">];
let Subjects = SubjectList<[HLSLBufferObj, ExternalGlobalVar], ErrorDiag>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 23c8eb2d163c86..29925c18f1bd95 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7670,6 +7670,26 @@ The full documentation is available here: https://docs.microsoft.com/en-us/windo
}];
}
+def VkExtBuiltinInputDocs : Documentation {
+ let Category = DocCatFunction;
+ let Content = [{
+The ``vk::ext_builtin_input`` attribute applies to HLSL global variables.
+The ``BuiltIn`` value is the ID of the BuiltIn in the SPIR-V specification.
+
+The full documentation is available here: https://microsoft.github.io/hlsl-specs/proposals/0011-inline-spirv.html
+ }];
+}
+
+def VkExtBuiltinOutputDocs : Documentation {
+ let Category = DocCatFunction;
+ let Content = [{
+The ``vk::ext_builtin_output`` attribute applies to HLSL global variables.
+The ``BuiltIn`` value is the ID of the BuiltIn in the SPIR-V specification.
+
+The full documentation is available here: https://microsoft.github.io/hlsl-specs/proposals/0011-inline-spirv.html
+ }];
+}
+
def HLSLSV_ShaderTypeAttrDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
@@ -7828,6 +7848,17 @@ The full documentation is available here: https://docs.microsoft.com/en-us/windo
}];
}
+def HLSLSV_GroupIDDocs : Documentation {
+ let Category = DocCatFunction;
+ let Content = [{
+The ``SV_GroupID`` semantic, when applied to an input parameter, specifies a
+data binding to map the group ID to the specified parameter. This attribute
+is only supported in compute shaders.
+
+The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-groupid
+ }];
+}
+
def HLSLResourceBindingDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
diff --git a/clang/include/clang/Sema/ParsedAttr.h b/clang/include/clang/Sema/ParsedAttr.h
index 22cbd0d90ee432..d1ed6241a10b88 100644
--- a/clang/include/clang/Sema/ParsedAttr.h
+++ b/clang/include/clang/Sema/ParsedAttr.h
@@ -624,6 +624,17 @@ class ParsedAttr final
}
}
+ LangAS asVulkanLangAS() const {
+ switch (getParsedKind()) {
+ case ParsedAttr::AT_HLSLVkExtBuiltinInput:
+ return LangAS::vulkan_input;
+ case ParsedAttr::AT_HLSLVkExtBuiltinOutput:
+ return LangAS::vulkan_output;
+ default:
+ return LangAS::Default;
+ }
+ }
+
AttributeCommonInfo::Kind getKind() const {
return AttributeCommonInfo::Kind(Info.AttrKind);
}
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 06c541dec08cc8..d3ec45a367ae41 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -117,6 +117,8 @@ class SemaHLSL : public SemaBase {
void emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS, BinaryOperatorKind Opc);
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
+ void handleVkExtBuiltinInput(Decl *D, const ParsedAttr &AL);
+ void handleVkExtBuiltinOutput(Decl *D, const ParsedAttr &AL);
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index 6d8db5cf4ffd22..d36658f7eb8114 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -2545,6 +2545,10 @@ std::string Qualifiers::getAddrSpaceAsString(LangAS AS) {
return "__funcref";
case LangAS::hlsl_groupshared:
return "groupshared";
+ case LangAS::vulkan_input:
+ return "Input";
+ case LangAS::vulkan_output:
+ return "Output";
default:
return std::to_string(toTargetAddressSpace(AS));
}
diff --git a/clang/lib/Basic/TargetInfo.cpp b/clang/lib/Basic/TargetInfo.cpp
index 86befb1cbc74fc..1d2f1e44fd7560 100644
--- a/clang/lib/Basic/TargetInfo.cpp
+++ b/clang/lib/Basic/TargetInfo.cpp
@@ -47,6 +47,8 @@ static const LangASMap FakeAddrSpaceMap = {
11, // ptr32_uptr
12, // ptr64
13, // hlsl_groupshared
+ 14, // vulkan_input
+ 15, // vulkan_output
20, // wasm_funcref
};
diff --git a/clang/lib/Basic/Targets/AArch64.h b/clang/lib/Basic/Targets/AArch64.h
index ea3e4015d84265..1038cbc74d7966 100644
--- a/clang/lib/Basic/Targets/AArch64.h
+++ b/clang/lib/Basic/Targets/AArch64.h
@@ -44,6 +44,8 @@ static const unsigned ARM64AddrSpaceMap[] = {
static_cast<unsigned>(AArch64AddrSpace::ptr32_uptr),
static_cast<unsigned>(AArch64AddrSpace::ptr64),
0, // hlsl_groupshared
+ 0, // vulkan_input
+ 0, // vulkan_output
// Wasm address space values for this target are dummy values,
// as it is only enabled for Wasm targets.
20, // wasm_funcref
diff --git a/clang/lib/Basic/Targets/AMDGPU.cpp b/clang/lib/Basic/Targets/AMDGPU.cpp
index 078819183afdac..51a4ba22e6afe0 100644
--- a/clang/lib/Basic/Targets/AMDGPU.cpp
+++ b/clang/lib/Basic/Targets/AMDGPU.cpp
@@ -59,6 +59,8 @@ const LangASMap AMDGPUTargetInfo::AMDGPUDefIsGenMap = {
llvm::AMDGPUAS::FLAT_ADDRESS, // ptr32_uptr
llvm::AMDGPUAS::FLAT_ADDRESS, // ptr64
llvm::AMDGPUAS::FLAT_ADDRESS, // hlsl_groupshared
+ llvm::AMDGPUAS::FLAT_ADDRESS, // vulkan_input
+ llvm::AMDGPUAS::FLAT_ADDRESS, // vulkan_output
};
const LangASMap AMDGPUTargetInfo::AMDGPUDefIsPrivMap = {
@@ -83,6 +85,8 @@ const LangASMap AMDGPUTargetInfo::AMDGPUDefIsPrivMap = {
llvm::AMDGPUAS::FLAT_ADDRESS, // ptr32_uptr
llvm::AMDGPUAS::FLAT_ADDRESS, // ptr64
llvm::AMDGPUAS::FLAT_ADDRESS, // hlsl_groupshared
+ llvm::AMDGPUAS::FLAT_ADDRESS, // vulkan_input
+ llvm::AMDGPUAS::FLAT_ADDRESS, // vulkan_output
};
} // namespace targets
diff --git a/clang/lib/Basic/Targets/DirectX.h b/clang/lib/Basic/Targets/DirectX.h
index ab22d1281a4df7..3e44a1d5a266f5 100644
--- a/clang/lib/Basic/Targets/DirectX.h
+++ b/clang/lib/Basic/Targets/DirectX.h
@@ -42,6 +42,8 @@ static const unsigned DirectXAddrSpaceMap[] = {
0, // ptr32_uptr
0, // ptr64
3, // hlsl_groupshared
+ 0, // vulkan_input
+ 0, // vulkan_output
// Wasm address space values for this target are dummy values,
// as it is only enabled for Wasm targets.
20, // wasm_funcref
diff --git a/clang/lib/Basic/Targets/NVPTX.h b/clang/lib/Basic/Targets/NVPTX.h
index 165b28a60fb2a9..b97c3a2bd0a131 100644
--- a/clang/lib/Basic/Targets/NVPTX.h
+++ b/clang/lib/Basic/Targets/NVPTX.h
@@ -45,6 +45,8 @@ static const unsigned NVPTXAddrSpaceMap[] = {
0, // ptr32_uptr
0, // ptr64
0, // hlsl_groupshared
+ 0, // vulkan_input
+ 0, // vulkan_output
// Wasm address space values for this target are dummy values,
// as it is only enabled for Wasm targets.
20, // wasm_funcref
diff --git a/clang/lib/Basic/Targets/SPIR.h b/clang/lib/Basic/Targets/SPIR.h
index 85e4bd920d8535..3c9624b5daa172 100644
--- a/clang/lib/Basic/Targets/SPIR.h
+++ b/clang/lib/Basic/Targets/SPIR.h
@@ -47,6 +47,8 @@ static const unsigned SPIRDefIsPrivMap[] = {
0, // ptr32_uptr
0, // ptr64
0, // hlsl_groupshared
+ 7, // vulkan_input
+ 8, // vulkan_output
// Wasm address space values for this target are dummy values,
// as it is only enabled for Wasm targets.
20, // wasm_funcref
@@ -80,6 +82,8 @@ static const unsigned SPIRDefIsGenMap[] = {
0, // ptr32_uptr
0, // ptr64
0, // hlsl_groupshared
+ 7, // vulkan_input
+ 8, // vulkan_output
// Wasm address space values for this target are dummy values,
// as it is only enabled for Wasm targets.
20, // wasm_funcref
diff --git a/clang/lib/Basic/Targets/SystemZ.h b/clang/lib/Basic/Targets/SystemZ.h
index ef9a07033a6e4f..8e49c7f4260b86 100644
--- a/clang/lib/Basic/Targets/SystemZ.h
+++ b/clang/lib/Basic/Targets/SystemZ.h
@@ -42,6 +42,8 @@ static const unsigned ZOSAddressMap[] = {
1, // ptr32_uptr
0, // ptr64
0, // hlsl_groupshared
+ 0, // vulkan_input
+ 0, // vulkan_output
0 // wasm_funcref
};
diff --git a/clang/lib/Basic/Targets/TCE.h b/clang/lib/Basic/Targets/TCE.h
index d6280b02f07b25..6a4b4dee80762a 100644
--- a/clang/lib/Basic/Targets/TCE.h
+++ b/clang/lib/Basic/Targets/TCE.h
@@ -51,6 +51,8 @@ static const unsigned TCEOpenCLAddrSpaceMap[] = {
0, // ptr32_uptr
0, // ptr64
0, // hlsl_groupshared
+ 0, // vulkan_input
+ 0, // vulkan_output
// Wasm address space values for this target are dummy values,
// as it is only enabled for Wasm targets.
20, // wasm_funcref
diff --git a/clang/lib/Basic/Targets/WebAssembly.h b/clang/lib/Basic/Targets/WebAssembly.h
index 6c2fe8049ff47a..0edd8ab8c8494e 100644
--- a/clang/lib/Basic/Targets/WebAssembly.h
+++ b/clang/lib/Basic/Targets/WebAssembly.h
@@ -22,26 +22,28 @@ namespace clang {
namespace targets {
static const unsigned WebAssemblyAddrSpaceMap[] = {
- 0, // Default
- 0, // opencl_global
- 0, // opencl_local
- 0, // opencl_constant
- 0, // opencl_private
- 0, // opencl_generic
- 0, // opencl_global_device
- 0, // opencl_global_host
- 0, // cuda_device
- 0, // cuda_constant
- 0, // cuda_shared
- 0, // sycl_global
- 0, // sycl_global_device
- 0, // sycl_global_host
- 0, // sycl_local
- 0, // sycl_private
- 0, // ptr32_sptr
- 0, // ptr32_uptr
- 0, // ptr64
- 0, // hlsl_groupshared
+ 0, // Default
+ 0, // opencl_global
+ 0, // opencl_local
+ 0, // opencl_constant
+ 0, // opencl_private
+ 0, // opencl_generic
+ 0, // opencl_global_device
+ 0, // opencl_global_host
+ 0, // cuda_device
+ 0, // cuda_constant
+ 0, // cuda_shared
+ 0, // sycl_global
+ 0, // sycl_global_device
+ 0, // sycl_global_host
+ 0, // sycl_local
+ 0, // sycl_private
+ 0, // ptr32_sptr
+ 0, // ptr32_uptr
+ 0, // ptr64
+ 0, // hlsl_groupshared
+ 0, // vulkan_input
+ 0, // vulkan_output
20, // wasm_funcref
};
diff --git a/clang/lib/Basic/Targets/X86.h b/clang/lib/Basic/Targets/X86.h
index e2eba63b992355..4171184297476b 100644
--- a/clang/lib/Basic/Targets/X86.h
+++ b/clang/lib/Basic/Targets/X86.h
@@ -46,6 +46,8 @@ static const unsigned X86AddrSpaceMap[] = {
271, // ptr32_uptr
272, // ptr64
0, // hlsl_groupshared
+ 0, // vulkan_input
+ 0, // vulkan_output
// Wasm address space values for this target are dummy values,
// as it is only enabled for Wasm targets.
20, // wasm_funcref
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 7ba0d615018181..042f7a74485dc9 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -375,10 +375,43 @@ static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
return B.CreateCall(F, {B.getInt32(0)});
}
+llvm::Value *CGHLSLRuntime::emitInputBuiltin(IRBuilder<> &B,
+ const ParmVarDecl &D,
+ llvm::Type *Ty, unsigned BuiltInID,
+ LangAS AddressSpace) {
+ LLVMContext &Ctx = CGM.getLLVMContext();
+ // FIXME: keep track of which global is created, and reuse them.
+ llvm::GlobalVariable *GV = new llvm::GlobalVariable(
+ CGM.getModule(), Ty,
+ /* isConstant= */ false, GlobalValue::ExternalLinkage,
+ /* Initializer= */ nullptr,
+ /* Name= */ D.getName(),
+ /* InsertBefore= */ nullptr,
+ /* ThreadLocalMode= */ GlobalValue::NotThreadLocal,
+ /* AddressSpace */ CGM.getTarget().getTargetAddressSpace(AddressSpace));
+
+ MDNode *Decoration =
+ MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(11 /* BuiltIn */)),
+ ConstantAsMetadata::get(B.getInt32(BuiltInID))});
+ MDNode *Val = MDNode::get(Ctx, {Decoration});
+ GV->setMetadata("spirv.Decorations", Val);
+ return B.CreateLoad(Ty, GV);
+}
+
llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
const ParmVarDecl &D,
llvm::Type *Ty) {
assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
+
+ if (D.hasAttr<HLSLSV_GroupIDAttr>()) {
+ if (getArch() == llvm::Triple::spirv)
+ return emitInputBuiltin(B, D, Ty, /* WorkgroupID */ 26,
+ LangAS::vulkan_input);
+ else
+ // FIXME: getIntrinsic(getGroupIDIntrinsic())
+ return nullptr;
+ }
+
if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
llvm::Function *DxGroupIndex =
CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
@@ -525,6 +558,29 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
void CGHLSLRuntime::handleGlobalVarDefinition(const VarDecl *VD,
llvm::GlobalVariable *GV) {
+
+ if (HLSLVkExtBuiltinInputAttr *BuiltinAttr =
+ VD->getAttr<HLSLVkExtBuiltinInputAttr>()) {
+ LLVMContext &Ctx = CGM.getLLVMContext();
+ IRBuilder<> B(Ctx);
+ MDNode *Decoration = MDNode::get(
+ Ctx, {ConstantAsMetadata::get(B.getInt32(11 /* BuiltIn */)),
+ ConstantAsMetadata::get(B.getInt32(BuiltinAttr->getBuiltIn()))});
+ MDNode *Val = MDNode::get(Ctx, {Decoration});
+ GV->setMetadata("spirv.Decorations", Val);
+ }
+
+ if (HLSLVkExtBuiltinOutputAttr *BuiltinAttr =
+ VD->getAttr<HLSLVkExtBuiltinOutputAttr>()) {
+ LLVMContext &Ctx = CGM.getLLVMContext();
+ IRBuilder<> B(Ctx);
+ MDNode *Decoration = MDNode::get(
+ Ctx, {ConstantAsMetadata::get(B.getInt32(11 /* BuiltIn */)),
+ ConstantAsMetadata::get(B.getInt32(BuiltinAttr->getBuiltIn()))});
+ MDNode *Val = MDNode::get(Ctx, {Decoration});
+ GV->setMetadata("spirv.Decorations", Val);
+ }
+
// If the global variable has resource binding, add it to the list of globals
// that need resource binding initialization.
const HLSLResourceBindingAttr *RBA = VD->getAttr<HLSLResourceBindingAttr>();
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ff810cc535c087..6ca35416a1f015 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -20,6 +20,7 @@
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "clang/Basic/AddressSpaces.h"
#include "clang/Basic/Builtins.h"
#include "clang/Basic/HLSLRuntime.h"
@@ -123,6 +124,9 @@ class CGHLSLRuntime {
protected:
CodeGenModule &CGM;
+ llvm::Value *emitInputBuiltin(llvm::IRBuilder<> &B, const ParmVarDecl &D,
+ llvm::Type *Ty, unsigned BuiltInID,
+ LangAS AddressSpace);
llvm::Value *emitInputSemantic(llvm::IRBuilder<> &B, const ParmVarDecl &D,
llvm::Type *Ty);
diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp
index ba376f9ecfacde..5b93df8b6c6fc9 100644
--- a/clang/lib/CodeGen/CodeGenModule.cpp
+++ b/clang/lib/CodeGen/CodeGenModule.cpp
@@ -5047,6 +5047,10 @@ CodeGenModule::GetOrCreateLLVMGlobal(StringRef MangledName, llvm::Type *Ty,
if (LangOpts.OpenMP && !LangOpts.OpenMPSimd)
getOpenMPRuntime().registerTargetGlobalVariable(D, GV);
+ // HLSL related end of code gen work items.
+ if (LangOpts.HLSL)
+ getHLSLRuntime().handleGlobalVarDefinition(D, GV);
+
// FIXME: This code is overly simple and should be merged with other global
// handling.
GV->setConstant(D->getType().isConstantStorage(getContext(), false, false));
@@ -5628,9 +5632,6 @@ void CodeGenModule::EmitGlobalVarDefinition(const VarDecl *D,
getCUDARuntime().handleVarRegistration(D, *GV);
}
- if (LangOpts.HLSL)
- getHLSLRuntime().handleGlobalVarDefinition(D, GV);
-
GV->setInitializer(Init);
if (emitter)
emitter->finalize(GV);
diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp
index b36ea4012c26e1..1f95a3268e84e6 100644
--- a/clang/lib/Parse/ParseHLSL.cpp
+++ b/clang/lib/Parse/ParseHLSL.cpp
@@ -281,6 +281,7 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs,
Diag(Loc, diag::err_unknown_hlsl_semantic) << II;
return;
case ParsedAttr::AT_HLSLSV_GroupIndex:
+ case ParsedAttr::AT_HLSLSV_GroupID:
case ParsedAttr::AT_HLSLSV_DispatchThreadID:
break;
default:
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index d05d326178e1b8..aa9f89ab458bd3 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -6987,12 +6987,21 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
case ParsedAttr::AT_HLSLNumThreads:
S.HLSL().handleNumThreadsAttr(D, AL);
break;
+ case ParsedAttr::AT_HLSLVkExtBuiltinInput:
+ S.HLSL().handleVkExtBuiltinInput(D, AL);
+ break;
+ case ParsedAttr::AT_HLSLVkExtBuiltinOutput:
+ S.HLSL().handleVkExtBuiltinOutput(D, AL);
+ break;
case ParsedAttr::AT_HLSLWaveSize:
S.HLSL().handleWaveSizeAttr(D, AL);
break;
case ParsedAttr::AT_HLSLSV_GroupIndex:
handleSimpleAttribute<HLSLSV_GroupIndexAttr>(S, D, AL);
break;
+ case ParsedAttr::AT_HLSLSV_GroupID:
+ handleSimpleAttribute<HLSLSV_GroupIDAttr>(S, D, AL);
+ break;
case ParsedAttr::AT_HLSLGroupSharedAddressSpace:
handleSimpleAttribute<HLSLGroupSharedAddressSpaceAttr>(S, D, AL);
break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 65b0d9cd65637f..187fe5dda0a311 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -434,6 +434,7 @@ void SemaHLSL::CheckSemanticAnnotation(
switch (AnnotationAttr->getKind()) {
case attr::HLSLSV_DispatchThreadID:
case attr::HLSLSV_GroupIndex:
+ case attr::HLSLSV_GroupID:
if (ST == llvm::Triple::Compute)
return;
DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
@@ -645,6 +646,69 @@ void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS,
<< NewFnName << FixItHint::CreateReplacement(FullRange, OS.str());
}
+namespace {
+
+std::optional<uint32_t> validateVkExtBuiltin(Decl *D, const ParsedAttr &AL,
+ ASTContext &Ctx, Sema &SemaRef) {
+ llvm::Triple::OSType OSType = Ctx.getTargetInfo().getTriple().getOS();
+ if (!isa<VarDecl>(D)) {
+ // FIXME
+ SemaRef.Diag(AL.getLoc(), diag::err_hlsl_missing_resource_class);
+ return std::nullopt;
+ }
+
+ if (OSType != llvm::Triple::OSType::Vulkan) {
+ // FIXME
+ SemaRef.Diag(AL.getLoc(), diag::err_hlsl_missing_resource_class);
+ return std::nullopt;
+ }
+
+ uint32_t ID;
+ if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), ID)) {
+ // FIXME
+ SemaRef.Diag(AL.getLoc(), diag::err_hlsl_missing_resource_class);
+ return std::nullopt;
+ }
+
+ return ID;
+}
+
+} // anonymous namespace
+
+void SemaHLSL::handleVkExtBuiltinInput(Decl *D, const ParsedAttr &AL) {
+ std::optional<uint32_t> ID =
+ validateVkExtBuiltin(D, AL, getASTContext(), SemaRef);
+ if (!ID)
+ return;
+
+ VarDecl *VD = cast<VarDecl>(D);
+ QualType NewType =
+ SemaRef.Context.getAddrSpaceQualType(VD->getType(), LangAS::vulkan_input);
+ VD->setType(NewType);
+
+ HLSLVkExtBuiltinInputAttr *NewAttr = ::new (getASTContext())
+ HLSLVkExtBuiltinInputAttr(getASTContext(), AL, *ID);
+ assert(NewAttr);
+ VD->addAttr(NewAttr);
+}
+
+void SemaHLSL::handleVkExtBuiltinOutput(Decl *D, const ParsedAttr &AL) {
+ std::optional<uint32_t> ID =
+ validateVkExtBuiltin(D, AL, getASTContext(), SemaRef);
+ if (!ID)
+ return;
+
+ VarDecl *VD = cast<VarDecl>(D);
+ QualType NewType = SemaRef.Context.getAddrSpaceQualType(
+ VD->getType(), LangAS::vulkan_output);
+ VD->setType(NewType);
+
+ HLSLVkExtBuiltinOutputAttr *NewAttr = ::new (getASTContext())
+ HLSLVkExtBuiltinOutputAttr(getASTContext(), AL, *ID);
+ assert(NewAttr);
+ VD->addAttr(NewAttr);
+}
+
void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
llvm::VersionTuple SMVersion =
getASTContext().getTargetInfo().getTriple().getOSVersion();
diff --git a/clang/test/SemaTemplate/address_space-dependent.cpp b/clang/test/SemaTemplate/address_space-dependent.cpp
index 2ca9b8007ab418..ca79be5f43e1e8 100644
--- a/clang/test/SemaTemplate/address_space-dependent.cpp
+++ b/clang/test/SemaTemplate/address_space-dependent.cpp
@@ -43,7 +43,7 @@ void neg() {
template <long int I>
void tooBig() {
- __attribute__((address_space(I))) int *bounds; // expected-error {{address space is larger than the maximum supported (8388586)}}
+ __attribute__((address_space(I))) int *bounds; // expected-error {{address space is larger than the maximum supported (8388584)}}
}
template <long int I>
@@ -101,7 +101,7 @@ int main() {
car<1, 2, 3>(); // expected-note {{in instantiation of function template specialization 'car<1, 2, 3>' requested here}}
HasASTemplateFields<1> HASTF;
neg<-1>(); // expected-note {{in instantiation of function template specialization 'neg<-1>' requested here}}
- correct<0x7FFFE9>();
+ correct<0x7FFFE7>();
tooBig<8388650>(); // expected-note {{in instantiation of function template specialization 'tooBig<8388650L>' requested here}}
__attribute__((address_space(1))) char *x;
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index f66506beaa6ed6..c65989d37b3033 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -689,7 +689,7 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
if (IsConst && ST.isOpenCLEnv())
buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
- if (GVar && GVar->getAlign().valueOrOne().value() != 1) {
+ if (ST.isOpenCLEnv() && GVar && GVar->getAlign().valueOrOne().value() != 1) {
unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value();
buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment});
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index be38b22f70c583..7f624186ffee6d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -3179,8 +3179,10 @@ bool SPIRVInstructionSelector::selectGlobalValue(
unsigned AddrSpace = GV->getAddressSpace();
SPIRV::StorageClass::StorageClass Storage =
addressSpaceToStorageClass(AddrSpace, STI);
+ bool isIOVariable = Storage == SPIRV::StorageClass::Input ||
+ Storage == SPIRV::StorageClass::Output;
bool HasLnkTy = GV->getLinkage() != GlobalValue::InternalLinkage &&
- Storage != SPIRV::StorageClass::Function;
+ Storage != SPIRV::StorageClass::Function && !isIOVariable;
SPIRV::LinkageType::LinkageType LnkType =
(GV->isDeclaration() || GV->hasAvailableExternallyLinkage())
? SPIRV::LinkageType::Import
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 460f0127d4ffcd..52fefc7007f517 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -112,6 +112,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
const LLT p5 =
LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
+ const LLT p7 = LLT::pointer(7, PSize); // Input
+ const LLT p8 = LLT::pointer(8, PSize); // Output
// TODO: remove copy-pasting here by using concatenation in some way.
auto allPtrsScalarsAndVectors = {
@@ -148,7 +150,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1,
p2, p3, p4, p5, p6};
- auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
+ auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8};
bool IsExtendedInts =
ST.canUseExtension(
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index f9b361e163c909..ec51005ffcb078 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -185,6 +185,8 @@ addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
: SPIRV::StorageClass::CrossWorkgroup;
case 7:
return SPIRV::StorageClass::Input;
+ case 8:
+ return SPIRV::StorageClass::Output;
default:
report_fatal_error("Unknown address space");
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 11fd3a5c61dcae..80bbab2c17812f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -161,6 +161,8 @@ storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC) {
return 6;
case SPIRV::StorageClass::Input:
return 7;
+ case SPIRV::StorageClass::Output:
+ return 8;
default:
report_fatal_error("Unable to get address space id");
}
More information about the llvm-commits
mailing list