[clang] [llvm] [SPIR-V] DRAFT: ext_builtin_input/ext_builtin_output (PR #115187)

Nathan Gauër via cfe-commits cfe-commits at lists.llvm.org
Wed Nov 6 09:48:28 PST 2024


https://github.com/Keenuts created https://github.com/llvm/llvm-project/pull/115187

Draft PR exploring support for semantics and inline SPIR-V builtins (`vk::ext_builtin_input` / `vk::ext_builtin_output`).

Current usage:
```hlsl
// RUN: %clang --driver-mode=dxc -T cs_6_6 -spirv %s -O3 -E main

[[vk::ext_builtin_input(/* NumWorkGroups */ 24)]]
extern const uint3 numWorkGroups;

[[vk::ext_builtin_output(/* Random */ 25)]]
extern uint3 output;

[shader("compute")]
[numthreads(32, 1, 1)]
void main() {
  output = numWorkGroups;
}
```

From d06fb83ba0515e3097d375b0b9d8cd2922dd0f54 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Wed, 6 Nov 2024 17:48:13 +0100
Subject: [PATCH] [SPIR-V] DRAFT: ext_builtin_input/ext_builtin_output
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

do not submit

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 clang/include/clang/Basic/AddressSpaces.h     |  4 ++
 clang/include/clang/Basic/Attr.td             | 26 ++++++++
 clang/include/clang/Basic/AttrDocs.td         | 20 ++++++
 clang/include/clang/Sema/ParsedAttr.h         | 11 ++++
 clang/include/clang/Sema/SemaHLSL.h           |  2 +
 clang/lib/AST/TypePrinter.cpp                 |  4 ++
 clang/lib/Basic/TargetInfo.cpp                |  2 +
 clang/lib/Basic/Targets/AArch64.h             |  2 +
 clang/lib/Basic/Targets/AMDGPU.cpp            |  4 ++
 clang/lib/Basic/Targets/DirectX.h             |  2 +
 clang/lib/Basic/Targets/NVPTX.h               |  2 +
 clang/lib/Basic/Targets/SPIR.h                |  4 ++
 clang/lib/Basic/Targets/SystemZ.h             |  2 +
 clang/lib/Basic/Targets/TCE.h                 |  2 +
 clang/lib/Basic/Targets/WebAssembly.h         | 42 +++++++------
 clang/lib/Basic/Targets/X86.h                 |  2 +
 clang/lib/CodeGen/CGHLSLRuntime.cpp           | 23 +++++++
 clang/lib/CodeGen/CodeGenModule.cpp           |  7 ++-
 clang/lib/Sema/SemaDeclAttr.cpp               |  6 ++
 clang/lib/Sema/SemaHLSL.cpp                   | 63 +++++++++++++++++++
 .../SemaTemplate/address_space-dependent.cpp  |  4 +-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  4 +-
 llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp  |  4 +-
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp          |  2 +
 llvm/lib/Target/SPIRV/SPIRVUtils.h            |  2 +
 25 files changed, 219 insertions(+), 27 deletions(-)

diff --git a/clang/include/clang/Basic/AddressSpaces.h b/clang/include/clang/Basic/AddressSpaces.h
index 7b723d508fff17..b829d9d3be8688 100644
--- a/clang/include/clang/Basic/AddressSpaces.h
+++ b/clang/include/clang/Basic/AddressSpaces.h
@@ -59,6 +59,10 @@ enum class LangAS : unsigned {
   // HLSL specific address spaces.
   hlsl_groupshared,
 
+  // Vulkan specific address spaces.
+  vulkan_input,
+  vulkan_output,
+
   // Wasm specific address spaces.
   wasm_funcref,
 
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 24cfb5ddb6d4ca..3b88735cc5fe94 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -164,6 +164,16 @@ def HLSLBufferObj : SubsetSubject<HLSLBuffer,
                     [{isa<HLSLBufferDecl>(S)}],
                     "cbuffer/tbuffer">;
 
+// vk::ext_builtin_input may only be applied to `extern const` globals.
+def HLSLInputBuiltin : SubsetSubject<Var, [{S->hasGlobalStorage() &&
+    S->getType().isConstQualified() && S->getStorageClass() == SC_Extern}],
+    "extern const global variables">;
+
+// vk::ext_builtin_output may only be applied to non-const `extern` globals.
+def HLSLOutputBuiltin : SubsetSubject<Var, [{S->hasGlobalStorage() &&
+    !S->getType().isConstQualified() && S->getStorageClass() == SC_Extern}],
+    "extern non-const global variables">;
+
 def ClassTmpl : SubsetSubject<CXXRecord, [{S->getDescribedClassTemplate()}],
                               "class templates">;
 
@@ -4588,6 +4598,22 @@ def HLSLNumThreads: InheritableAttr {
   let Documentation = [NumThreadsDocs];
 }
 
+def HLSLVkExtBuiltinInput: InheritableAttr {
+  let Spellings = [CXX11<"vk", "ext_builtin_input">];
+  let Args = [UnsignedArgument<"BuiltIn">]; // SPIR-V BuiltIn enumerant.
+  let Subjects = SubjectList<[HLSLInputBuiltin], ErrorDiag>;
+  let LangOpts = [HLSL];
+  let Documentation = [VkExtBuiltinInputDocs];
+}
+
+def HLSLVkExtBuiltinOutput: InheritableAttr {
+  let Spellings = [CXX11<"vk", "ext_builtin_output">];
+  let Args = [UnsignedArgument<"BuiltIn">]; // SPIR-V BuiltIn enumerant.
+  let Subjects = SubjectList<[HLSLOutputBuiltin], ErrorDiag>;
+  let LangOpts = [HLSL];
+  let Documentation = [VkExtBuiltinOutputDocs];
+}
+
 def HLSLSV_GroupIndex: HLSLAnnotationAttr {
   let Spellings = [HLSLAnnotation<"SV_GroupIndex">];
   let Subjects = SubjectList<[ParmVar, GlobalVar]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 23c8eb2d163c86..5c7508a407baf7 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7670,6 +7670,26 @@ The full documentation is available here: https://docs.microsoft.com/en-us/windo
   }];
 }
 
+def VkExtBuiltinInputDocs : Documentation {
+  let Category = DocCatVariable;
+  let Content = [{
+The ``vk::ext_builtin_input`` attribute applies to ``extern const`` HLSL global variables and is only accepted when targeting Vulkan.
+It declares a SPIR-V ``Input`` builtin variable; the ``BuiltIn`` argument is the numeric value of the ``BuiltIn`` enumerant in the SPIR-V specification.
+
+The full documentation is available here: https://microsoft.github.io/hlsl-specs/proposals/0011-inline-spirv.html
+  }];
+}
+
+def VkExtBuiltinOutputDocs : Documentation {
+  let Category = DocCatVariable;
+  let Content = [{
+The ``vk::ext_builtin_output`` attribute applies to non-const ``extern`` HLSL global variables and is only accepted when targeting Vulkan.
+It declares a SPIR-V ``Output`` builtin variable; the ``BuiltIn`` argument is the numeric value of the ``BuiltIn`` enumerant in the SPIR-V specification.
+
+The full documentation is available here: https://microsoft.github.io/hlsl-specs/proposals/0011-inline-spirv.html
+  }];
+}
+
 def HLSLSV_ShaderTypeAttrDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/Sema/ParsedAttr.h b/clang/include/clang/Sema/ParsedAttr.h
index 22cbd0d90ee432..d1ed6241a10b88 100644
--- a/clang/include/clang/Sema/ParsedAttr.h
+++ b/clang/include/clang/Sema/ParsedAttr.h
@@ -624,6 +624,17 @@ class ParsedAttr final
     }
   }
 
+  /// Returns the Vulkan address space implied by a vk::ext_builtin_*
+  /// attribute, or LangAS::Default for any other attribute kind.
+  LangAS asVulkanLangAS() const {
+    const auto K = getParsedKind();
+    if (K == ParsedAttr::AT_HLSLVkExtBuiltinInput)
+      return LangAS::vulkan_input;
+    if (K == ParsedAttr::AT_HLSLVkExtBuiltinOutput)
+      return LangAS::vulkan_output;
+    return LangAS::Default;
+  }
+
   AttributeCommonInfo::Kind getKind() const {
     return AttributeCommonInfo::Kind(Info.AttrKind);
   }
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 06c541dec08cc8..d3ec45a367ae41 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -117,6 +117,8 @@ class SemaHLSL : public SemaBase {
   void emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS, BinaryOperatorKind Opc);
 
   void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
+  void handleVkExtBuiltinInput(Decl *D, const ParsedAttr &AL);
+  void handleVkExtBuiltinOutput(Decl *D, const ParsedAttr &AL);
   void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index 6d8db5cf4ffd22..d36658f7eb8114 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -2545,6 +2545,10 @@ std::string Qualifiers::getAddrSpaceAsString(LangAS AS) {
     return "__funcref";
   case LangAS::hlsl_groupshared:
     return "groupshared";
+  case LangAS::vulkan_input:
+    return "Input";
+  case LangAS::vulkan_output:
+    return "Output";
   default:
     return std::to_string(toTargetAddressSpace(AS));
   }
diff --git a/clang/lib/Basic/TargetInfo.cpp b/clang/lib/Basic/TargetInfo.cpp
index 86befb1cbc74fc..1d2f1e44fd7560 100644
--- a/clang/lib/Basic/TargetInfo.cpp
+++ b/clang/lib/Basic/TargetInfo.cpp
@@ -47,6 +47,8 @@ static const LangASMap FakeAddrSpaceMap = {
     11, // ptr32_uptr
     12, // ptr64
     13, // hlsl_groupshared
+    14, // vulkan_input
+    15, // vulkan_output
     20, // wasm_funcref
 };
 
diff --git a/clang/lib/Basic/Targets/AArch64.h b/clang/lib/Basic/Targets/AArch64.h
index ea3e4015d84265..1038cbc74d7966 100644
--- a/clang/lib/Basic/Targets/AArch64.h
+++ b/clang/lib/Basic/Targets/AArch64.h
@@ -44,6 +44,8 @@ static const unsigned ARM64AddrSpaceMap[] = {
     static_cast<unsigned>(AArch64AddrSpace::ptr32_uptr),
     static_cast<unsigned>(AArch64AddrSpace::ptr64),
     0, // hlsl_groupshared
+    0, // vulkan_input
+    0, // vulkan_output
     // Wasm address space values for this target are dummy values,
     // as it is only enabled for Wasm targets.
     20, // wasm_funcref
diff --git a/clang/lib/Basic/Targets/AMDGPU.cpp b/clang/lib/Basic/Targets/AMDGPU.cpp
index 078819183afdac..51a4ba22e6afe0 100644
--- a/clang/lib/Basic/Targets/AMDGPU.cpp
+++ b/clang/lib/Basic/Targets/AMDGPU.cpp
@@ -59,6 +59,8 @@ const LangASMap AMDGPUTargetInfo::AMDGPUDefIsGenMap = {
     llvm::AMDGPUAS::FLAT_ADDRESS,     // ptr32_uptr
     llvm::AMDGPUAS::FLAT_ADDRESS,     // ptr64
     llvm::AMDGPUAS::FLAT_ADDRESS,     // hlsl_groupshared
+    llvm::AMDGPUAS::FLAT_ADDRESS,     // vulkan_input
+    llvm::AMDGPUAS::FLAT_ADDRESS,     // vulkan_output
 };
 
 const LangASMap AMDGPUTargetInfo::AMDGPUDefIsPrivMap = {
@@ -83,6 +85,8 @@ const LangASMap AMDGPUTargetInfo::AMDGPUDefIsPrivMap = {
     llvm::AMDGPUAS::FLAT_ADDRESS, // ptr32_uptr
     llvm::AMDGPUAS::FLAT_ADDRESS, // ptr64
     llvm::AMDGPUAS::FLAT_ADDRESS, // hlsl_groupshared
+    llvm::AMDGPUAS::FLAT_ADDRESS, // vulkan_input
+    llvm::AMDGPUAS::FLAT_ADDRESS, // vulkan_output
 
 };
 } // namespace targets
diff --git a/clang/lib/Basic/Targets/DirectX.h b/clang/lib/Basic/Targets/DirectX.h
index ab22d1281a4df7..3e44a1d5a266f5 100644
--- a/clang/lib/Basic/Targets/DirectX.h
+++ b/clang/lib/Basic/Targets/DirectX.h
@@ -42,6 +42,8 @@ static const unsigned DirectXAddrSpaceMap[] = {
     0, // ptr32_uptr
     0, // ptr64
     3, // hlsl_groupshared
+    0, // vulkan_input
+    0, // vulkan_output
     // Wasm address space values for this target are dummy values,
     // as it is only enabled for Wasm targets.
     20, // wasm_funcref
diff --git a/clang/lib/Basic/Targets/NVPTX.h b/clang/lib/Basic/Targets/NVPTX.h
index 165b28a60fb2a9..b97c3a2bd0a131 100644
--- a/clang/lib/Basic/Targets/NVPTX.h
+++ b/clang/lib/Basic/Targets/NVPTX.h
@@ -45,6 +45,8 @@ static const unsigned NVPTXAddrSpaceMap[] = {
     0, // ptr32_uptr
     0, // ptr64
     0, // hlsl_groupshared
+    0, // vulkan_input
+    0, // vulkan_output
     // Wasm address space values for this target are dummy values,
     // as it is only enabled for Wasm targets.
     20, // wasm_funcref
diff --git a/clang/lib/Basic/Targets/SPIR.h b/clang/lib/Basic/Targets/SPIR.h
index 85e4bd920d8535..3c9624b5daa172 100644
--- a/clang/lib/Basic/Targets/SPIR.h
+++ b/clang/lib/Basic/Targets/SPIR.h
@@ -47,6 +47,8 @@ static const unsigned SPIRDefIsPrivMap[] = {
     0, // ptr32_uptr
     0, // ptr64
     0, // hlsl_groupshared
+    7, // vulkan_input
+    8, // vulkan_output
     // Wasm address space values for this target are dummy values,
     // as it is only enabled for Wasm targets.
     20, // wasm_funcref
@@ -80,6 +82,8 @@ static const unsigned SPIRDefIsGenMap[] = {
     0, // ptr32_uptr
     0, // ptr64
     0, // hlsl_groupshared
+    7, // vulkan_input
+    8, // vulkan_output
     // Wasm address space values for this target are dummy values,
     // as it is only enabled for Wasm targets.
     20, // wasm_funcref
diff --git a/clang/lib/Basic/Targets/SystemZ.h b/clang/lib/Basic/Targets/SystemZ.h
index ef9a07033a6e4f..8e49c7f4260b86 100644
--- a/clang/lib/Basic/Targets/SystemZ.h
+++ b/clang/lib/Basic/Targets/SystemZ.h
@@ -42,6 +42,8 @@ static const unsigned ZOSAddressMap[] = {
     1, // ptr32_uptr
     0, // ptr64
     0, // hlsl_groupshared
+    0, // vulkan_input
+    0, // vulkan_output
     0  // wasm_funcref
 };
 
diff --git a/clang/lib/Basic/Targets/TCE.h b/clang/lib/Basic/Targets/TCE.h
index d6280b02f07b25..6a4b4dee80762a 100644
--- a/clang/lib/Basic/Targets/TCE.h
+++ b/clang/lib/Basic/Targets/TCE.h
@@ -51,6 +51,8 @@ static const unsigned TCEOpenCLAddrSpaceMap[] = {
     0, // ptr32_uptr
     0, // ptr64
     0, // hlsl_groupshared
+    0, // vulkan_input
+    0, // vulkan_output
     // Wasm address space values for this target are dummy values,
     // as it is only enabled for Wasm targets.
     20, // wasm_funcref
diff --git a/clang/lib/Basic/Targets/WebAssembly.h b/clang/lib/Basic/Targets/WebAssembly.h
index 6c2fe8049ff47a..0edd8ab8c8494e 100644
--- a/clang/lib/Basic/Targets/WebAssembly.h
+++ b/clang/lib/Basic/Targets/WebAssembly.h
@@ -22,26 +22,28 @@ namespace clang {
 namespace targets {
 
 static const unsigned WebAssemblyAddrSpaceMap[] = {
-    0, // Default
-    0, // opencl_global
-    0, // opencl_local
-    0, // opencl_constant
-    0, // opencl_private
-    0, // opencl_generic
-    0, // opencl_global_device
-    0, // opencl_global_host
-    0, // cuda_device
-    0, // cuda_constant
-    0, // cuda_shared
-    0, // sycl_global
-    0, // sycl_global_device
-    0, // sycl_global_host
-    0, // sycl_local
-    0, // sycl_private
-    0, // ptr32_sptr
-    0, // ptr32_uptr
-    0, // ptr64
-    0, // hlsl_groupshared
+    0,  // Default
+    0,  // opencl_global
+    0,  // opencl_local
+    0,  // opencl_constant
+    0,  // opencl_private
+    0,  // opencl_generic
+    0,  // opencl_global_device
+    0,  // opencl_global_host
+    0,  // cuda_device
+    0,  // cuda_constant
+    0,  // cuda_shared
+    0,  // sycl_global
+    0,  // sycl_global_device
+    0,  // sycl_global_host
+    0,  // sycl_local
+    0,  // sycl_private
+    0,  // ptr32_sptr
+    0,  // ptr32_uptr
+    0,  // ptr64
+    0,  // hlsl_groupshared
+    0,  // vulkan_input
+    0,  // vulkan_output
     20, // wasm_funcref
 };
 
diff --git a/clang/lib/Basic/Targets/X86.h b/clang/lib/Basic/Targets/X86.h
index e2eba63b992355..4171184297476b 100644
--- a/clang/lib/Basic/Targets/X86.h
+++ b/clang/lib/Basic/Targets/X86.h
@@ -46,6 +46,8 @@ static const unsigned X86AddrSpaceMap[] = {
     271, // ptr32_uptr
     272, // ptr64
     0,   // hlsl_groupshared
+    0,   // vulkan_input
+    0,   // vulkan_output
     // Wasm address space values for this target are dummy values,
     // as it is only enabled for Wasm targets.
     20, // wasm_funcref
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 7ba0d615018181..59aeeb7546fef4 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -525,6 +525,29 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
 
 void CGHLSLRuntime::handleGlobalVarDefinition(const VarDecl *VD,
                                               llvm::GlobalVariable *GV) {
+
+  if (HLSLVkExtBuiltinInputAttr *BuiltinAttr =
+          VD->getAttr<HLSLVkExtBuiltinInputAttr>()) {
+    LLVMContext &Ctx = CGM.getLLVMContext();
+    IRBuilder<> B(Ctx);
+    MDNode *Decoration = MDNode::get(
+        Ctx, {ConstantAsMetadata::get(B.getInt32(11 /* BuiltIn */)),
+              ConstantAsMetadata::get(B.getInt32(BuiltinAttr->getBuiltIn()))});
+    MDNode *Val = MDNode::get(Ctx, {Decoration});
+    GV->setMetadata("spirv.Decorations", Val);
+  }
+
+  if (HLSLVkExtBuiltinOutputAttr *BuiltinAttr =
+          VD->getAttr<HLSLVkExtBuiltinOutputAttr>()) {
+    LLVMContext &Ctx = CGM.getLLVMContext();
+    IRBuilder<> B(Ctx);
+    MDNode *Decoration = MDNode::get(
+        Ctx, {ConstantAsMetadata::get(B.getInt32(11 /* BuiltIn */)),
+              ConstantAsMetadata::get(B.getInt32(BuiltinAttr->getBuiltIn()))});
+    MDNode *Val = MDNode::get(Ctx, {Decoration});
+    GV->setMetadata("spirv.Decorations", Val);
+  }
+
   // If the global variable has resource binding, add it to the list of globals
   // that need resource binding initialization.
   const HLSLResourceBindingAttr *RBA = VD->getAttr<HLSLResourceBindingAttr>();
diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp
index ba376f9ecfacde..5b93df8b6c6fc9 100644
--- a/clang/lib/CodeGen/CodeGenModule.cpp
+++ b/clang/lib/CodeGen/CodeGenModule.cpp
@@ -5047,6 +5047,10 @@ CodeGenModule::GetOrCreateLLVMGlobal(StringRef MangledName, llvm::Type *Ty,
     if (LangOpts.OpenMP && !LangOpts.OpenMPSimd)
       getOpenMPRuntime().registerTargetGlobalVariable(D, GV);
 
+    // HLSL related end of code gen work items.
+    if (LangOpts.HLSL)
+      getHLSLRuntime().handleGlobalVarDefinition(D, GV);
+
     // FIXME: This code is overly simple and should be merged with other global
     // handling.
     GV->setConstant(D->getType().isConstantStorage(getContext(), false, false));
@@ -5628,9 +5632,6 @@ void CodeGenModule::EmitGlobalVarDefinition(const VarDecl *D,
     getCUDARuntime().handleVarRegistration(D, *GV);
   }
 
-  if (LangOpts.HLSL)
-    getHLSLRuntime().handleGlobalVarDefinition(D, GV);
-
   GV->setInitializer(Init);
   if (emitter)
     emitter->finalize(GV);
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index d05d326178e1b8..dec16bf826495f 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -6987,6 +6987,12 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLNumThreads:
     S.HLSL().handleNumThreadsAttr(D, AL);
     break;
+  case ParsedAttr::AT_HLSLVkExtBuiltinInput:
+    S.HLSL().handleVkExtBuiltinInput(D, AL);
+    break;
+  case ParsedAttr::AT_HLSLVkExtBuiltinOutput:
+    S.HLSL().handleVkExtBuiltinOutput(D, AL);
+    break;
   case ParsedAttr::AT_HLSLWaveSize:
     S.HLSL().handleWaveSizeAttr(D, AL);
     break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 65b0d9cd65637f..9efa2f68787fec 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -645,6 +645,69 @@ void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS,
       << NewFnName << FixItHint::CreateReplacement(FullRange, OS.str());
 }
 
+/// Common checks for vk::ext_builtin_input and vk::ext_builtin_output.
+/// Returns the SPIR-V BuiltIn ID carried by \p AL, or std::nullopt once a
+/// diagnostic has been emitted.
+static std::optional<uint32_t> validateVkExtBuiltin(Decl *D,
+                                                    const ParsedAttr &AL,
+                                                    ASTContext &Ctx,
+                                                    Sema &SemaRef) {
+  if (!isa<VarDecl>(D)) {
+    // FIXME: emit a dedicated "wrong subject" diagnostic.
+    SemaRef.Diag(AL.getLoc(), diag::err_hlsl_missing_resource_class);
+    return std::nullopt;
+  }
+
+  // The attributes are only meaningful when targeting Vulkan (SPIR-V).
+  llvm::Triple::OSType OSType = Ctx.getTargetInfo().getTriple().getOS();
+  if (OSType != llvm::Triple::OSType::Vulkan) {
+    // FIXME: emit a dedicated "unsupported target" diagnostic.
+    SemaRef.Diag(AL.getLoc(), diag::err_hlsl_missing_resource_class);
+    return std::nullopt;
+  }
+
+  // checkUInt32Argument already diagnoses a non-constant or out-of-range
+  // argument; emitting a second, unrelated error here would be noise.
+  uint32_t ID = 0;
+  if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), ID))
+    return std::nullopt;
+
+  return ID;
+}
+
+/// Validates \p AL, moves \p D into address space \p AS and attaches an
+/// \p AttrT carrying the BuiltIn ID. Shared by the input/output handlers,
+/// which differ only in the attribute class and the address space.
+template <typename AttrT>
+static void handleVkExtBuiltin(Decl *D, const ParsedAttr &AL, Sema &SemaRef,
+                               LangAS AS) {
+  ASTContext &Ctx = SemaRef.getASTContext();
+  std::optional<uint32_t> ID = validateVkExtBuiltin(D, AL, Ctx, SemaRef);
+  if (!ID)
+    return;
+
+  // The address space carries the SPIR-V storage class down to CodeGen.
+  VarDecl *VD = cast<VarDecl>(D);
+  VD->setType(Ctx.getAddrSpaceQualType(VD->getType(), AS));
+
+  // Attributes are allocated with the ASTContext's allocator.
+  VD->addAttr(::new (Ctx) AttrT(Ctx, AL, *ID));
+}
+
+/// Handles [[vk::ext_builtin_input(ID)]]: the variable is placed in the
+/// vulkan_input address space (SPIR-V Input storage class).
+void SemaHLSL::handleVkExtBuiltinInput(Decl *D, const ParsedAttr &AL) {
+  handleVkExtBuiltin<HLSLVkExtBuiltinInputAttr>(D, AL, SemaRef,
+                                                LangAS::vulkan_input);
+}
+
+/// Handles [[vk::ext_builtin_output(ID)]]: the variable is placed in the
+/// vulkan_output address space (SPIR-V Output storage class).
+void SemaHLSL::handleVkExtBuiltinOutput(Decl *D, const ParsedAttr &AL) {
+  handleVkExtBuiltin<HLSLVkExtBuiltinOutputAttr>(D, AL, SemaRef,
+                                                 LangAS::vulkan_output);
+}
+
 void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
   llvm::VersionTuple SMVersion =
       getASTContext().getTargetInfo().getTriple().getOSVersion();
diff --git a/clang/test/SemaTemplate/address_space-dependent.cpp b/clang/test/SemaTemplate/address_space-dependent.cpp
index 2ca9b8007ab418..ca79be5f43e1e8 100644
--- a/clang/test/SemaTemplate/address_space-dependent.cpp
+++ b/clang/test/SemaTemplate/address_space-dependent.cpp
@@ -43,7 +43,7 @@ void neg() {
 
 template <long int I>
 void tooBig() {
-  __attribute__((address_space(I))) int *bounds; // expected-error {{address space is larger than the maximum supported (8388586)}}
+  __attribute__((address_space(I))) int *bounds; // expected-error {{address space is larger than the maximum supported (8388584)}}
 }
 
 template <long int I>
@@ -101,7 +101,7 @@ int main() {
   car<1, 2, 3>(); // expected-note {{in instantiation of function template specialization 'car<1, 2, 3>' requested here}}
   HasASTemplateFields<1> HASTF;
   neg<-1>(); // expected-note {{in instantiation of function template specialization 'neg<-1>' requested here}}
-  correct<0x7FFFE9>();
+  correct<0x7FFFE7>();
   tooBig<8388650>(); // expected-note {{in instantiation of function template specialization 'tooBig<8388650L>' requested here}}
 
   __attribute__((address_space(1))) char *x;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index be38b22f70c583..7f624186ffee6d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -3179,8 +3179,10 @@ bool SPIRVInstructionSelector::selectGlobalValue(
   unsigned AddrSpace = GV->getAddressSpace();
   SPIRV::StorageClass::StorageClass Storage =
       addressSpaceToStorageClass(AddrSpace, STI);
+  bool isIOVariable = Storage == SPIRV::StorageClass::Input ||
+                      Storage == SPIRV::StorageClass::Output;
   bool HasLnkTy = GV->getLinkage() != GlobalValue::InternalLinkage &&
-                  Storage != SPIRV::StorageClass::Function;
+                  Storage != SPIRV::StorageClass::Function && !isIOVariable;
   SPIRV::LinkageType::LinkageType LnkType =
       (GV->isDeclaration() || GV->hasAvailableExternallyLinkage())
           ? SPIRV::LinkageType::Import
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 460f0127d4ffcd..52fefc7007f517 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -112,6 +112,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   const LLT p5 =
       LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
   const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
+  const LLT p7 = LLT::pointer(7, PSize); // Input
+  const LLT p8 = LLT::pointer(8, PSize); // Output
 
   // TODO: remove copy-pasting here by using concatenation in some way.
   auto allPtrsScalarsAndVectors = {
@@ -148,7 +150,7 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1,
                                        p2, p3,  p4,  p5,  p6};
 
-  auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
+  auto allPtrs = {p0, p1, p2, p3, p4, p5, p6, p7, p8};
 
   bool IsExtendedInts =
       ST.canUseExtension(
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index f9b361e163c909..ec51005ffcb078 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -185,6 +185,8 @@ addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
                : SPIRV::StorageClass::CrossWorkgroup;
   case 7:
     return SPIRV::StorageClass::Input;
+  case 8:
+    return SPIRV::StorageClass::Output;
   default:
     report_fatal_error("Unknown address space");
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 11fd3a5c61dcae..80bbab2c17812f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -161,6 +161,8 @@ storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC) {
     return 6;
   case SPIRV::StorageClass::Input:
     return 7;
+  case SPIRV::StorageClass::Output:
+    return 8;
   default:
     report_fatal_error("Unable to get address space id");
   }



More information about the cfe-commits mailing list