[clang] [HLSL] Implement `SpirvType` and `SpirvOpaqueType` (PR #134034)
Cassandra Beckley via cfe-commits
cfe-commits at lists.llvm.org
Tue Apr 1 23:34:51 PDT 2025
https://github.com/cassiebeckley updated https://github.com/llvm/llvm-project/pull/134034
>From 78ac1bc4225b41bc4b9fbd9fd9ab9dc82a2953ca Mon Sep 17 00:00:00 2001
From: Cassandra Beckley <cbeckley at google.com>
Date: Tue, 1 Apr 2025 23:12:02 -0700
Subject: [PATCH 1/2] [HLSL] Implement `SpirvType` and `SpirvOpaqueType`
This implements the design proposed by [Representing SpirvType in
Clang's Type System](https://github.com/llvm/wg-hlsl/pull/181). It
creates `HLSLInlineSpirvType` as a new `Type` subclass, and
`__hlsl_spirv_type` as a new builtin type template to create such a
type.
This new type is lowered to the `spirv.Type` target extension type, as
described in [Target Extension Types for Inline SPIR-V and Decorated
Types](https://github.com/llvm/wg-hlsl/blob/main/proposals/0017-inline-spirv-and-decorated-types.md).
---
clang/include/clang-c/Index.h | 3 +-
clang/include/clang/AST/ASTContext.h | 5 +
clang/include/clang/AST/ASTNodeTraverser.h | 18 +++
clang/include/clang/AST/PropertiesBase.td | 1 +
clang/include/clang/AST/RecursiveASTVisitor.h | 11 ++
clang/include/clang/AST/Type.h | 142 +++++++++++++++++-
clang/include/clang/AST/TypeLoc.h | 19 +++
clang/include/clang/AST/TypeProperties.td | 18 +++
clang/include/clang/Basic/BuiltinTemplates.td | 18 ++-
.../clang/Basic/DiagnosticSemaKinds.td | 3 +
clang/include/clang/Basic/TypeNodes.td | 1 +
.../clang/Serialization/ASTRecordReader.h | 2 +
.../clang/Serialization/ASTRecordWriter.h | 14 ++
.../clang/Serialization/TypeBitCodes.def | 1 +
clang/lib/AST/ASTContext.cpp | 59 ++++++++
clang/lib/AST/ASTImporter.cpp | 42 ++++++
clang/lib/AST/ASTStructuralEquivalence.cpp | 17 +++
clang/lib/AST/ExprConstant.cpp | 1 +
clang/lib/AST/ItaniumMangle.cpp | 40 ++++-
clang/lib/AST/MicrosoftMangle.cpp | 5 +
clang/lib/AST/Type.cpp | 14 ++
clang/lib/AST/TypePrinter.cpp | 48 ++++++
clang/lib/CodeGen/CGDebugInfo.cpp | 8 +
clang/lib/CodeGen/CGDebugInfo.h | 1 +
clang/lib/CodeGen/CodeGenFunction.cpp | 2 +
clang/lib/CodeGen/CodeGenTypes.cpp | 6 +
clang/lib/CodeGen/ItaniumCXXABI.cpp | 2 +
clang/lib/CodeGen/Targets/SPIR.cpp | 90 ++++++++++-
clang/lib/Headers/CMakeLists.txt | 1 +
clang/lib/Headers/hlsl.h | 4 +
clang/lib/Headers/hlsl/hlsl_spirv.h | 30 ++++
clang/lib/Sema/SemaExpr.cpp | 1 +
clang/lib/Sema/SemaLookup.cpp | 21 ++-
clang/lib/Sema/SemaTemplate.cpp | 103 ++++++++++++-
clang/lib/Sema/SemaTemplateDeduction.cpp | 2 +
clang/lib/Sema/SemaType.cpp | 1 +
clang/lib/Sema/TreeTransform.h | 7 +
clang/lib/Serialization/ASTReader.cpp | 9 ++
clang/lib/Serialization/ASTWriter.cpp | 4 +
.../test/AST/HLSL/Inputs/pch_spirv_type.hlsl | 6 +
clang/test/AST/HLSL/ast-dump-SpirvType.hlsl | 27 ++++
clang/test/AST/HLSL/pch_spirv_type.hlsl | 17 +++
clang/test/AST/HLSL/vector-alias.hlsl | 105 +++++++------
.../inline/SpirvType.alignment.hlsl | 16 ++
.../inline/SpirvType.dx.error.hlsl | 12 ++
clang/test/CodeGenHLSL/inline/SpirvType.hlsl | 68 +++++++++
.../inline/SpirvType.incomplete.hlsl | 14 ++
.../inline/SpirvType.literal.error.hlsl | 11 ++
clang/tools/libclang/CIndex.cpp | 5 +
clang/tools/libclang/CXType.cpp | 1 +
.../TableGen/ClangBuiltinTemplatesEmitter.cpp | 72 +++++++--
51 files changed, 1052 insertions(+), 76 deletions(-)
create mode 100644 clang/lib/Headers/hlsl/hlsl_spirv.h
create mode 100644 clang/test/AST/HLSL/Inputs/pch_spirv_type.hlsl
create mode 100644 clang/test/AST/HLSL/ast-dump-SpirvType.hlsl
create mode 100644 clang/test/AST/HLSL/pch_spirv_type.hlsl
create mode 100644 clang/test/CodeGenHLSL/inline/SpirvType.alignment.hlsl
create mode 100644 clang/test/CodeGenHLSL/inline/SpirvType.dx.error.hlsl
create mode 100644 clang/test/CodeGenHLSL/inline/SpirvType.hlsl
create mode 100644 clang/test/CodeGenHLSL/inline/SpirvType.incomplete.hlsl
create mode 100644 clang/test/CodeGenHLSL/inline/SpirvType.literal.error.hlsl
diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h
index 38e2417dcd181..757f8a3afc758 100644
--- a/clang/include/clang-c/Index.h
+++ b/clang/include/clang-c/Index.h
@@ -3034,7 +3034,8 @@ enum CXTypeKind {
/* HLSL Types */
CXType_HLSLResource = 179,
- CXType_HLSLAttributedResource = 180
+ CXType_HLSLAttributedResource = 180,
+ CXType_HLSLInlineSpirv = 181
};
/**
diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index a24f30815e6b9..c62f9f7672010 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -260,6 +260,7 @@ class ASTContext : public RefCountedBase<ASTContext> {
DependentBitIntTypes;
mutable llvm::FoldingSet<BTFTagAttributedType> BTFTagAttributedTypes;
llvm::FoldingSet<HLSLAttributedResourceType> HLSLAttributedResourceTypes;
+ llvm::FoldingSet<HLSLInlineSpirvType> HLSLInlineSpirvTypes;
mutable llvm::FoldingSet<CountAttributedType> CountAttributedTypes;
@@ -1795,6 +1796,10 @@ class ASTContext : public RefCountedBase<ASTContext> {
QualType Wrapped, QualType Contained,
const HLSLAttributedResourceType::Attributes &Attrs);
+ QualType getHLSLInlineSpirvType(uint32_t Opcode, uint32_t Size,
+ uint32_t Alignment,
+ ArrayRef<SpirvOperand> Operands);
+
QualType
getSubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl,
unsigned Index,
diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h
index f086d8134a64b..fd9108221590e 100644
--- a/clang/include/clang/AST/ASTNodeTraverser.h
+++ b/clang/include/clang/AST/ASTNodeTraverser.h
@@ -450,6 +450,24 @@ class ASTNodeTraverser
if (!Contained.isNull())
Visit(Contained);
}
+  /// Visits the result type of every type-id operand of an inline SPIR-V
+  /// type. Constant-id and literal operands contribute no child nodes.
+  void VisitHLSLInlineSpirvType(const HLSLInlineSpirvType *T) {
+    using OperandKind = SpirvOperand::SpirvOperandKind;
+
+    for (const SpirvOperand &Op : T->getOperands()) {
+      switch (Op.getKind()) {
+      case OperandKind::kTypeId:
+        Visit(Op.getResultType());
+        break;
+
+      case OperandKind::kLiteral:
+      case OperandKind::kConstantId:
+        break;
+
+      default:
+        llvm_unreachable("Invalid SpirvOperand kind!");
+      }
+    }
+  }
void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *) {}
void
VisitSubstTemplateTypeParmPackType(const SubstTemplateTypeParmPackType *T) {
diff --git a/clang/include/clang/AST/PropertiesBase.td b/clang/include/clang/AST/PropertiesBase.td
index 5171555008ac9..7d5e6671fec7d 100644
--- a/clang/include/clang/AST/PropertiesBase.td
+++ b/clang/include/clang/AST/PropertiesBase.td
@@ -147,6 +147,7 @@ def UInt64 : CountPropertyType<"uint64_t">;
def UnaryTypeTransformKind : EnumPropertyType<"UnaryTransformType::UTTKind">;
def VectorKind : EnumPropertyType<"VectorKind">;
def TypeCoupledDeclRefInfo : PropertyType;
+def HLSLSpirvOperand : PropertyType<"SpirvOperand"> { let PassByReference = 1; }
def ExceptionSpecInfo : PropertyType<"FunctionProtoType::ExceptionSpecInfo"> {
let BufferElementTypes = [ QualType ];
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 0530996ed20d3..255e39a46db09 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -1154,6 +1154,14 @@ DEF_TRAVERSE_TYPE(BTFTagAttributedType,
DEF_TRAVERSE_TYPE(HLSLAttributedResourceType,
{ TRY_TO(TraverseType(T->getWrappedType())); })
+// Traverse the result type of each constant-id and type-id operand; literal
+// operands are skipped.
+DEF_TRAVERSE_TYPE(HLSLInlineSpirvType, {
+  for (const SpirvOperand &Op : T->getOperands()) {
+    if (!Op.isConstant() && !Op.isType())
+      continue;
+    TRY_TO(TraverseType(Op.getResultType()));
+  }
+})
+
DEF_TRAVERSE_TYPE(ParenType, { TRY_TO(TraverseType(T->getInnerType())); })
DEF_TRAVERSE_TYPE(MacroQualifiedType,
@@ -1457,6 +1465,9 @@ DEF_TRAVERSE_TYPELOC(BTFTagAttributedType,
DEF_TRAVERSE_TYPELOC(HLSLAttributedResourceType,
{ TRY_TO(TraverseTypeLoc(TL.getWrappedLoc())); })
+// Delegates to TraverseType on the TypeLoc's underlying type.
+DEF_TRAVERSE_TYPELOC(HLSLInlineSpirvType,
+ { TRY_TO(TraverseType(TL.getType())); })
+
DEF_TRAVERSE_TYPELOC(ElaboratedType, {
if (TL.getQualifierLoc()) {
TRY_TO(TraverseNestedNameSpecifierLoc(TL.getQualifierLoc()));
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index cfd417068abb7..f351e68d5297d 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2652,6 +2652,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
bool isHLSLSpecificType() const; // Any HLSL specific type
bool isHLSLBuiltinIntangibleType() const; // Any HLSL builtin intangible type
bool isHLSLAttributedResourceType() const;
+ bool isHLSLInlineSpirvType() const;
bool isHLSLResourceRecord() const;
bool isHLSLIntangibleType()
const; // Any HLSL intangible type (builtin, array, class)
@@ -6330,6 +6331,140 @@ class HLSLAttributedResourceType : public Type, public llvm::FoldingSetNode {
findHandleTypeOnResource(const Type *RT);
};
+/// Instances of this class represent operands to a SPIR-V type instruction.
+///
+/// An operand is one of:
+///  - a constant ID: an integral value together with its result type,
+///  - a literal: an integral value only,
+///  - a type ID: a type only.
+/// getValue() and getResultType() assert that the operand kind actually
+/// carries the requested field.
+class SpirvOperand {
+public:
+  enum SpirvOperandKind : unsigned char {
+    kInvalid,    ///< Uninitialized.
+    kConstantId, ///< Integral value to represent as a SPIR-V OpConstant
+                 ///< instruction ID.
+    kLiteral,    ///< Integral value to represent as an immediate literal.
+    kTypeId,     ///< Type to represent as a SPIR-V type ID.
+
+    kMax,
+  };
+
+private:
+  SpirvOperandKind Kind = kInvalid;
+
+  QualType ResultType;
+  llvm::APInt Value; // Signedness of constants is represented by ResultType.
+
+public:
+  /// Creates an invalid (kInvalid) operand.
+  SpirvOperand() = default;
+
+  SpirvOperand(SpirvOperandKind Kind, QualType ResultType, llvm::APInt Value)
+      : Kind(Kind), ResultType(ResultType), Value(Value) {}
+
+  // Copy, move and destruction are intentionally left to the compiler: every
+  // member already has correct value semantics, so hand-written versions
+  // would only duplicate (and could drift from) the member list.
+
+  bool operator==(const SpirvOperand &Other) const {
+    return Kind == Other.Kind && ResultType == Other.ResultType &&
+           Value == Other.Value;
+  }
+
+  bool operator!=(const SpirvOperand &Other) const { return !(*this == Other); }
+
+  SpirvOperandKind getKind() const { return Kind; }
+
+  bool isValid() const { return Kind != kInvalid && Kind < kMax; }
+  bool isConstant() const { return Kind == kConstantId; }
+  bool isLiteral() const { return Kind == kLiteral; }
+  bool isType() const { return Kind == kTypeId; }
+
+  /// Returns the integral value; only valid for constant and literal operands.
+  llvm::APInt getValue() const {
+    assert((isConstant() || isLiteral()) &&
+           "This is not an operand with a value!");
+    return Value;
+  }
+
+  /// Returns the result type; only valid for constant and type operands.
+  QualType getResultType() const {
+    assert((isConstant() || isType()) &&
+           "This is not an operand with a result type!");
+    return ResultType;
+  }
+
+  static SpirvOperand createConstant(QualType ResultType, llvm::APInt Val) {
+    return SpirvOperand(kConstantId, ResultType, Val);
+  }
+
+  static SpirvOperand createLiteral(llvm::APInt Val) {
+    return SpirvOperand(kLiteral, QualType(), Val);
+  }
+
+  static SpirvOperand createType(QualType T) {
+    // A type operand has no value; store a default APInt rather than slicing
+    // a default APSInt down to its APInt base.
+    return SpirvOperand(kTypeId, T, llvm::APInt());
+  }
+
+  void Profile(llvm::FoldingSetNodeID &ID) const {
+    ID.AddInteger(Kind);
+    ID.AddPointer(ResultType.getAsOpaquePtr());
+    Value.Profile(ID);
+  }
+};
+
+/// Represents an arbitrary, user-specified SPIR-V type instruction.
+///
+/// The node stores the instruction opcode, the size and alignment supplied by
+/// the user, and the instruction operands as trailing objects placed directly
+/// after the node. Instances are created and uniqued by ASTContext.
+class HLSLInlineSpirvType final
+    : public Type,
+      public llvm::FoldingSetNode,
+      private llvm::TrailingObjects<HLSLInlineSpirvType, SpirvOperand> {
+  friend class ASTContext; // ASTContext creates these
+  friend TrailingObjects;
+
+private:
+  uint32_t Opcode;
+  uint32_t Size;
+  uint32_t Alignment;
+  size_t NumOperands;
+
+  HLSLInlineSpirvType(uint32_t Opcode, uint32_t Size, uint32_t Alignment,
+                      ArrayRef<SpirvOperand> Operands)
+      : Type(HLSLInlineSpirv, QualType(), TypeDependence::None), Opcode(Opcode),
+        Size(Size), Alignment(Alignment), NumOperands(Operands.size()) {
+    // The trailing storage is raw, unconstructed memory. Copy-construct each
+    // operand in place; assigning into it would run SpirvOperand's (and
+    // APInt's) assignment operator on an object whose lifetime never began.
+    SpirvOperand *Storage = getTrailingObjects<SpirvOperand>();
+    for (size_t I = 0; I < NumOperands; I++)
+      new (&Storage[I]) SpirvOperand(Operands[I]);
+  }
+
+public:
+  uint32_t getOpcode() const { return Opcode; }
+  uint32_t getSize() const { return Size; }
+  uint32_t getAlignment() const { return Alignment; }
+  ArrayRef<SpirvOperand> getOperands() const {
+    return {getTrailingObjects<SpirvOperand>(), NumOperands};
+  }
+
+  bool isSugared() const { return false; }
+  QualType desugar() const { return QualType(this, 0); }
+
+  void Profile(llvm::FoldingSetNodeID &ID) {
+    Profile(ID, Opcode, Size, Alignment, getOperands());
+  }
+
+  /// Folds opcode, size, alignment and every operand into \p ID.
+  static void Profile(llvm::FoldingSetNodeID &ID, uint32_t Opcode,
+                      uint32_t Size, uint32_t Alignment,
+                      ArrayRef<SpirvOperand> Operands) {
+    ID.AddInteger(Opcode);
+    ID.AddInteger(Size);
+    ID.AddInteger(Alignment);
+    for (auto &Operand : Operands) {
+      Operand.Profile(ID);
+    }
+  }
+
+  static bool classof(const Type *T) {
+    return T->getTypeClass() == HLSLInlineSpirv;
+  }
+};
+
class TemplateTypeParmType : public Type, public llvm::FoldingSetNode {
friend class ASTContext; // ASTContext creates these
@@ -8458,13 +8593,18 @@ inline bool Type::isHLSLBuiltinIntangibleType() const {
}
inline bool Type::isHLSLSpecificType() const {
- return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType();
+ return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType() ||
+ isHLSLInlineSpirvType();
}
inline bool Type::isHLSLAttributedResourceType() const {
return isa<HLSLAttributedResourceType>(this);
}
+inline bool Type::isHLSLInlineSpirvType() const {
+ return isa<HLSLInlineSpirvType>(this);
+}
+
inline bool Type::isTemplateTypeParmType() const {
return isa<TemplateTypeParmType>(CanonicalType);
}
diff --git a/clang/include/clang/AST/TypeLoc.h b/clang/include/clang/AST/TypeLoc.h
index 92661b8b13fe0..53c7ea8c65df2 100644
--- a/clang/include/clang/AST/TypeLoc.h
+++ b/clang/include/clang/AST/TypeLoc.h
@@ -973,6 +973,25 @@ class HLSLAttributedResourceTypeLoc
}
};
+/// Local TypeLoc data for HLSLInlineSpirvType: a single source location.
+struct HLSLInlineSpirvTypeLocInfo {
+ SourceLocation Loc;
+};
+
+/// TypeLoc wrapper for HLSLInlineSpirvType. Its local source range is the
+/// single stored location used as both begin and end.
+class HLSLInlineSpirvTypeLoc
+ : public ConcreteTypeLoc<UnqualTypeLoc, HLSLInlineSpirvTypeLoc,
+ HLSLInlineSpirvType, HLSLInlineSpirvTypeLocInfo> {
+public:
+ SourceLocation getSpirvTypeLoc() const { return getLocalData()->Loc; }
+ void setSpirvTypeLoc(SourceLocation loc) const { getLocalData()->Loc = loc; }
+
+ SourceRange getLocalSourceRange() const {
+ return SourceRange(getSpirvTypeLoc(), getSpirvTypeLoc());
+ }
+ // Context is unused; the stored location is simply set to loc.
+ void initializeLocal(ASTContext &Context, SourceLocation loc) {
+ setSpirvTypeLoc(loc);
+ }
+};
+
struct ObjCObjectTypeLocInfo {
SourceLocation TypeArgsLAngleLoc;
SourceLocation TypeArgsRAngleLoc;
diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td
index 391fd26a086f7..784c2104f1bb2 100644
--- a/clang/include/clang/AST/TypeProperties.td
+++ b/clang/include/clang/AST/TypeProperties.td
@@ -719,6 +719,24 @@ let Class = HLSLAttributedResourceType in {
}]>;
}
+// Properties of HLSLInlineSpirvType (opcode, size, alignment, operands) and
+// the creator that rebuilds the type from them through
+// ASTContext::getHLSLInlineSpirvType.
+let Class = HLSLInlineSpirvType in {
+ def : Property<"opcode", UInt32> {
+ let Read = [{ node->getOpcode() }];
+ }
+ def : Property<"size", UInt32> {
+ let Read = [{ node->getSize() }];
+ }
+ def : Property<"alignment", UInt32> {
+ let Read = [{ node->getAlignment() }];
+ }
+ def : Property<"operands", Array<HLSLSpirvOperand>> {
+ let Read = [{ node->getOperands() }];
+ }
+ def : Creator<[{
+ return ctx.getHLSLInlineSpirvType(opcode, size, alignment, operands);
+ }]>;
+}
+
let Class = DependentAddressSpaceType in {
def : Property<"pointeeType", QualType> {
let Read = [{ node->getPointeeType() }];
diff --git a/clang/include/clang/Basic/BuiltinTemplates.td b/clang/include/clang/Basic/BuiltinTemplates.td
index d46ce063d2f7e..5b9672b395955 100644
--- a/clang/include/clang/Basic/BuiltinTemplates.td
+++ b/clang/include/clang/Basic/BuiltinTemplates.td
@@ -28,25 +28,37 @@ class BuiltinNTTP<string type_name> : TemplateArg<""> {
}
def SizeT : BuiltinNTTP<"size_t"> {}
+def Uint32T: BuiltinNTTP<"uint32_t"> {}
class BuiltinTemplate<list<TemplateArg> template_head> {
list<TemplateArg> TemplateHead = template_head;
}
+// Marker subclasses of BuiltinTemplate; they add no fields of their own.
+// NOTE(review): presumably these let the emitter group builtin templates by
+// the language they belong to (C++ vs. HLSL) - confirm against the emitter.
+class CPlusPlusBuiltinTemplate<list<TemplateArg> template_head> : BuiltinTemplate<template_head>;
+
+class HLSLBuiltinTemplate<list<TemplateArg> template_head> : BuiltinTemplate<template_head>;
+
// template <template <class T, T... Ints> IntSeq, class T, T N>
-def __make_integer_seq : BuiltinTemplate<
+def __make_integer_seq : CPlusPlusBuiltinTemplate<
[Template<[Class<"T">, NTTP<"T", "Ints", /*is_variadic=*/1>], "IntSeq">, Class<"T">, NTTP<"T", "N">]>;
// template <size_t, class... T>
-def __type_pack_element : BuiltinTemplate<
+def __type_pack_element : CPlusPlusBuiltinTemplate<
[SizeT, Class<"T", /*is_variadic=*/1>]>;
// template <template <class... Args> BaseTemplate,
// template <class TypeMember> HasTypeMember,
// class HasNoTypeMember
// class... Ts>
-def __builtin_common_type : BuiltinTemplate<
+def __builtin_common_type : CPlusPlusBuiltinTemplate<
[Template<[Class<"Args", /*is_variadic=*/1>], "BaseTemplate">,
Template<[Class<"TypeMember">], "HasTypeMember">,
Class<"HasNoTypeMember">,
Class<"Ts", /*is_variadic=*/1>]>;
+
+// template <uint32_t Opcode,
+// uint32_t Size,
+// uint32_t Alignment,
+// typename ...Operands>
+def __hlsl_spirv_type : HLSLBuiltinTemplate<
+[Uint32T, Uint32T, Uint32T, Class<"Operands", /*is_variadic=*/1>]>;
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 265bed2df43cf..287e139f02a2c 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12709,6 +12709,9 @@ def err_hlsl_expect_arg_const_int_one_or_neg_one: Error<
def err_invalid_hlsl_resource_type: Error<
"invalid __hlsl_resource_t type attributes">;
+def err_hlsl_spirv_only: Error<"%0 is only available for the SPIR-V target">;
+def err_hlsl_vk_literal_must_contain_constant: Error<"the argument to vk::Literal must be a vk::integral_constant">;
+
// Layout randomization diagnostics.
def err_non_designated_init_used : Error<
"a randomized struct can only be initialized with a designated initializer">;
diff --git a/clang/include/clang/Basic/TypeNodes.td b/clang/include/clang/Basic/TypeNodes.td
index 7e550ca2992f3..567b8a5ca5a4d 100644
--- a/clang/include/clang/Basic/TypeNodes.td
+++ b/clang/include/clang/Basic/TypeNodes.td
@@ -94,6 +94,7 @@ def ElaboratedType : TypeNode<Type>, NeverCanonical;
def AttributedType : TypeNode<Type>, NeverCanonical;
def BTFTagAttributedType : TypeNode<Type>, NeverCanonical;
def HLSLAttributedResourceType : TypeNode<Type>;
+def HLSLInlineSpirvType : TypeNode<Type>;
def TemplateTypeParmType : TypeNode<Type>, AlwaysDependent, LeafType;
def SubstTemplateTypeParmType : TypeNode<Type>, NeverCanonical;
def SubstTemplateTypeParmPackType : TypeNode<Type>, AlwaysDependent;
diff --git a/clang/include/clang/Serialization/ASTRecordReader.h b/clang/include/clang/Serialization/ASTRecordReader.h
index 7117b7246739b..79d33315d4fee 100644
--- a/clang/include/clang/Serialization/ASTRecordReader.h
+++ b/clang/include/clang/Serialization/ASTRecordReader.h
@@ -214,6 +214,8 @@ class ASTRecordReader
TypeCoupledDeclRefInfo readTypeCoupledDeclRefInfo();
+ SpirvOperand readHLSLSpirvOperand();
+
/// Read a declaration name, advancing Idx.
// DeclarationName readDeclarationName(); (inherited)
DeclarationNameLoc readDeclarationNameLoc(DeclarationName Name);
diff --git a/clang/include/clang/Serialization/ASTRecordWriter.h b/clang/include/clang/Serialization/ASTRecordWriter.h
index 84d77e46016b7..9653b709d3ef5 100644
--- a/clang/include/clang/Serialization/ASTRecordWriter.h
+++ b/clang/include/clang/Serialization/ASTRecordWriter.h
@@ -151,6 +151,20 @@ class ASTRecordWriter
writeBool(Info.isDeref());
}
+  /// Emit a SpirvOperand as its kind, result type and integer value.
+  ///
+  /// All three fields are always written so the record layout is uniform; a
+  /// field that does not apply to the operand's kind is written from a
+  /// default-constructed QualType / APInt. The operand is taken by reference
+  /// to avoid copying its APInt on every call.
+  void writeHLSLSpirvOperand(const SpirvOperand &Op) {
+    QualType ResultType;
+    llvm::APInt Value;
+
+    if (Op.isConstant() || Op.isType())
+      ResultType = Op.getResultType();
+    if (Op.isConstant() || Op.isLiteral())
+      Value = Op.getValue();
+
+    Record->push_back(Op.getKind());
+    writeQualType(ResultType);
+    writeAPInt(Value);
+  }
+
/// Emit a source range.
void AddSourceRange(SourceRange Range, LocSeq *Seq = nullptr) {
return Writer->AddSourceRange(Range, *Record, Seq);
diff --git a/clang/include/clang/Serialization/TypeBitCodes.def b/clang/include/clang/Serialization/TypeBitCodes.def
index 3c78b87805010..b8cde2e370960 100644
--- a/clang/include/clang/Serialization/TypeBitCodes.def
+++ b/clang/include/clang/Serialization/TypeBitCodes.def
@@ -68,5 +68,6 @@ TYPE_BIT_CODE(PackIndexing, PACK_INDEXING, 56)
TYPE_BIT_CODE(CountAttributed, COUNT_ATTRIBUTED, 57)
TYPE_BIT_CODE(ArrayParameter, ARRAY_PARAMETER, 58)
TYPE_BIT_CODE(HLSLAttributedResource, HLSLRESOURCE_ATTRIBUTED, 59)
+TYPE_BIT_CODE(HLSLInlineSpirv, HLSL_INLINE_SPIRV, 60)
#undef TYPE_BIT_CODE
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 552b5823add36..fb6a7b5a34175 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -2454,6 +2454,19 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
return getTypeInfo(
cast<HLSLAttributedResourceType>(T)->getWrappedType().getTypePtr());
+ // Layout of an inline SPIR-V type comes from the user-supplied size and
+ // alignment stored on the type node.
+ case Type::HLSLInlineSpirv: {
+ const auto *ST = cast<HLSLInlineSpirvType>(T);
+ // Size is specified in bytes, convert to bits
+ Width = ST->getSize() * 8;
+ // NOTE(review): Alignment is used unscaled while Size is converted from
+ // bytes to bits - confirm getAlignment() is already expressed in bits.
+ Align = ST->getAlignment();
+ if (Width == 0 && Align == 0) {
+ // We are defaulting to laying out opaque SPIR-V types as 32-bit ints.
+ Width = 32;
+ Align = 32;
+ }
+ break;
+ }
+
case Type::Atomic: {
// Start with the base type information.
TypeInfo Info = getTypeInfo(cast<AtomicType>(T)->getValueType());
@@ -3458,6 +3471,7 @@ static void encodeTypeForFunctionPointerAuth(const ASTContext &Ctx,
return;
}
case Type::HLSLAttributedResource:
+ case Type::HLSLInlineSpirv:
llvm_unreachable("should never get here");
break;
case Type::DeducedTemplateSpecialization:
@@ -4179,6 +4193,7 @@ QualType ASTContext::getVariableArrayDecayedType(QualType type) const {
case Type::DependentBitInt:
case Type::ArrayParameter:
case Type::HLSLAttributedResource:
+ case Type::HLSLInlineSpirv:
llvm_unreachable("type should never be variably-modified");
// These types can be variably-modified but should never need to
@@ -5444,6 +5459,31 @@ QualType ASTContext::getHLSLAttributedResourceType(
return QualType(Ty, 0);
}
+
+/// Returns the uniqued HLSLInlineSpirvType for the given opcode, size,
+/// alignment and operands, creating and registering it on first use.
+QualType ASTContext::getHLSLInlineSpirvType(uint32_t Opcode, uint32_t Size,
+                                            uint32_t Alignment,
+                                            ArrayRef<SpirvOperand> Operands) {
+  llvm::FoldingSetNodeID Key;
+  HLSLInlineSpirvType::Profile(Key, Opcode, Size, Alignment, Operands);
+
+  void *InsertPos = nullptr;
+  if (HLSLInlineSpirvType *Existing =
+          HLSLInlineSpirvTypes.FindNodeOrInsertPos(Key, InsertPos))
+    return QualType(Existing, 0);
+
+  // Room for the node itself followed by its trailing operands.
+  unsigned AllocSize =
+      sizeof(HLSLInlineSpirvType) + Operands.size() * sizeof(SpirvOperand);
+  void *Storage = Allocate(AllocSize, alignof(HLSLInlineSpirvType));
+
+  auto *NewTy =
+      new (Storage) HLSLInlineSpirvType(Opcode, Size, Alignment, Operands);
+
+  Types.push_back(NewTy);
+  HLSLInlineSpirvTypes.InsertNode(NewTy, InsertPos);
+
+  return QualType(NewTy, 0);
+}
+
/// Retrieve a substitution-result type.
QualType ASTContext::getSubstTemplateTypeParmType(
QualType Replacement, Decl *AssociatedDecl, unsigned Index,
@@ -9335,6 +9375,7 @@ void ASTContext::getObjCEncodingForTypeImpl(QualType T, std::string &S,
return;
case Type::HLSLAttributedResource:
+ case Type::HLSLInlineSpirv:
llvm_unreachable("unexpected type");
case Type::ArrayParameter:
@@ -11763,6 +11804,22 @@ QualType ASTContext::mergeTypes(QualType LHS, QualType RHS, bool OfBlockPointer,
return LHS;
return {};
}
+  case Type::HLSLInlineSpirv: {
+    // Two inline SPIR-V types merge only when they are identical: same
+    // opcode, size, alignment and operand list. On success LHS is returned,
+    // otherwise a null QualType.
+    const HLSLInlineSpirvType *LHSTy = LHS->castAs<HLSLInlineSpirvType>();
+    const HLSLInlineSpirvType *RHSTy = RHS->castAs<HLSLInlineSpirvType>();
+
+    // ArrayRef equality compares lengths before elements, so operand lists
+    // of different sizes are rejected instead of being indexed out of range
+    // or matched on a common prefix.
+    if (LHSTy->getOpcode() == RHSTy->getOpcode() &&
+        LHSTy->getSize() == RHSTy->getSize() &&
+        LHSTy->getAlignment() == RHSTy->getAlignment() &&
+        LHSTy->getOperands() == RHSTy->getOperands())
+      return LHS;
+    return {};
+  }
}
llvm_unreachable("Invalid Type::Class!");
@@ -13746,6 +13803,7 @@ static QualType getCommonNonSugarTypeNode(ASTContext &Ctx, const Type *X,
SUGAR_FREE_TYPE(SubstTemplateTypeParmPack)
SUGAR_FREE_TYPE(UnresolvedUsing)
SUGAR_FREE_TYPE(HLSLAttributedResource)
+ SUGAR_FREE_TYPE(HLSLInlineSpirv)
#undef SUGAR_FREE_TYPE
#define NON_UNIQUE_TYPE(Class) UNEXPECTED_TYPE(Class, "non-unique")
NON_UNIQUE_TYPE(TypeOfExpr)
@@ -14089,6 +14147,7 @@ static QualType getCommonSugarTypeNode(ASTContext &Ctx, const Type *X,
CANONICAL_TYPE(FunctionProto)
CANONICAL_TYPE(IncompleteArray)
CANONICAL_TYPE(HLSLAttributedResource)
+ CANONICAL_TYPE(HLSLInlineSpirv)
CANONICAL_TYPE(LValueReference)
CANONICAL_TYPE(ObjCInterface)
CANONICAL_TYPE(ObjCObject)
diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
index 81acb013b0f7d..8afb29cef24a4 100644
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -1832,6 +1832,48 @@ ExpectedType clang::ASTNodeImporter::VisitHLSLAttributedResourceType(
ToWrappedType, ToContainedType, ToAttrs);
}
+/// Imports an inline SPIR-V type: opcode, size and alignment are copied, the
+/// result type of every constant/type operand is imported, and the type is
+/// rebuilt in the destination context. Returns the first import error, if any.
+ExpectedType clang::ASTNodeImporter::VisitHLSLInlineSpirvType(
+    const clang::HLSLInlineSpirvType *T) {
+  Error Err = Error::success();
+
+  uint32_t ToOpcode = T->getOpcode();
+  uint32_t ToSize = T->getSize();
+  uint32_t ToAlignment = T->getAlignment();
+
+  ArrayRef<SpirvOperand> FromOperands = T->getOperands();
+
+  llvm::SmallVector<SpirvOperand> ToOperands;
+  ToOperands.reserve(FromOperands.size());
+
+  for (const SpirvOperand &Operand : FromOperands) {
+    using SpirvOperandKind = SpirvOperand::SpirvOperandKind;
+
+    switch (Operand.getKind()) {
+    case SpirvOperandKind::kConstantId:
+      ToOperands.push_back(SpirvOperand::createConstant(
+          importChecked(Err, Operand.getResultType()), Operand.getValue()));
+      break;
+    case SpirvOperandKind::kLiteral:
+      ToOperands.push_back(SpirvOperand::createLiteral(Operand.getValue()));
+      break;
+    case SpirvOperandKind::kTypeId:
+      ToOperands.push_back(SpirvOperand::createType(
+          importChecked(Err, Operand.getResultType())));
+      break;
+    default:
+      llvm_unreachable("Invalid SpirvOperand kind");
+    }
+  }
+
+  // Checked once after the loop so Err is always inspected, including when
+  // the type has no operands at all.
+  if (Err)
+    return std::move(Err);
+
+  assert(ToOperands.size() == FromOperands.size() &&
+         "every operand must be imported exactly once");
+
+  return Importer.getToContext().getHLSLInlineSpirvType(
+      ToOpcode, ToSize, ToAlignment, ToOperands);
+}
+
ExpectedType clang::ASTNodeImporter::VisitConstantMatrixType(
const clang::ConstantMatrixType *T) {
ExpectedType ToElementTypeOrErr = import(T->getElementType());
diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp
index c769722521d9c..f213368f3e1cc 100644
--- a/clang/lib/AST/ASTStructuralEquivalence.cpp
+++ b/clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -1119,6 +1119,23 @@ static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context,
return false;
break;
+  case Type::HLSLInlineSpirv: {
+    // Equivalent only if opcode, size, alignment and the whole operand list
+    // match.
+    const auto *Spirv1 = cast<HLSLInlineSpirvType>(T1);
+    const auto *Spirv2 = cast<HLSLInlineSpirvType>(T2);
+
+    if (Spirv1->getOpcode() != Spirv2->getOpcode() ||
+        Spirv1->getSize() != Spirv2->getSize() ||
+        Spirv1->getAlignment() != Spirv2->getAlignment())
+      return false;
+
+    // ArrayRef comparison checks the operand count first, so lists of
+    // different lengths are never indexed out of range or matched on a
+    // common prefix.
+    // NOTE(review): SpirvOperand equality compares result types by QualType
+    // identity - confirm that is adequate when T1 and T2 belong to different
+    // ASTContexts.
+    if (Spirv1->getOperands() != Spirv2->getOperands())
+      return false;
+    break;
+  }
+
case Type::Paren:
if (!IsStructurallyEquivalent(Context, cast<ParenType>(T1)->getInnerType(),
cast<ParenType>(T2)->getInnerType()))
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 80ece3c4ed7e1..96055b03ccd73 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -12439,6 +12439,7 @@ GCCTypeClass EvaluateBuiltinClassifyType(QualType T,
case Type::ObjCObjectPointer:
case Type::Pipe:
case Type::HLSLAttributedResource:
+ case Type::HLSLInlineSpirv:
// Classify all other types that don't fit into the regular
// classification the same way.
return GCCTypeClass::None;
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index b81981606866a..fe72305cd7535 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -2453,6 +2453,7 @@ bool CXXNameMangler::mangleUnresolvedTypeOrSimpleId(QualType Ty,
case Type::Attributed:
case Type::BTFTagAttributed:
case Type::HLSLAttributedResource:
+ case Type::HLSLInlineSpirv:
case Type::Auto:
case Type::DeducedTemplateSpecialization:
case Type::PackExpansion:
@@ -4654,6 +4655,44 @@ void CXXNameMangler::mangleType(const HLSLAttributedResourceType *T) {
mangleType(T->getWrappedType());
}
+/// Mangles an inline SPIR-V type as the vendor type
+/// "spirv_type_<opcode>_<size>_<alignment>", followed by one vendor qualifier
+/// per operand ("_Const", "_Lit" or "_Type") and that operand's value or type.
+void CXXNameMangler::mangleType(const HLSLInlineSpirvType *T) {
+  SmallString<20> TypeNameStr;
+  llvm::raw_svector_ostream TypeNameOS(TypeNameStr);
+
+  TypeNameOS << "spirv_type";
+
+  TypeNameOS << "_" << T->getOpcode();
+  TypeNameOS << "_" << T->getSize();
+  TypeNameOS << "_" << T->getAlignment();
+
+  // The vendor type name is complete at this point; operands are mangled
+  // separately below and never appended to TypeNameStr.
+  mangleVendorType(TypeNameStr);
+
+  for (auto &Operand : T->getOperands()) {
+    using SpirvOperandKind = SpirvOperand::SpirvOperandKind;
+
+    switch (Operand.getKind()) {
+    case SpirvOperandKind::kConstantId:
+      mangleVendorQualifier("_Const");
+      mangleIntegerLiteral(Operand.getResultType(),
+                           llvm::APSInt(Operand.getValue()));
+      break;
+    case SpirvOperandKind::kLiteral:
+      mangleVendorQualifier("_Lit");
+      mangleIntegerLiteral(Context.getASTContext().IntTy,
+                           llvm::APSInt(Operand.getValue()));
+      break;
+    case SpirvOperandKind::kTypeId:
+      mangleVendorQualifier("_Type");
+      mangleType(Operand.getResultType());
+      break;
+    default:
+      llvm_unreachable("Invalid SpirvOperand kind");
+    }
+  }
+}
+
void CXXNameMangler::mangleIntegerLiteral(QualType T,
const llvm::APSInt &Value) {
// <expr-primary> ::= L <type> <value number> E # integer literal
@@ -4667,7 +4706,6 @@ void CXXNameMangler::mangleIntegerLiteral(QualType T,
mangleNumber(Value);
}
Out << 'E';
-
}
void CXXNameMangler::mangleMemberExprBase(const Expr *Base, bool IsArrow) {
diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp
index 7e964124a9fec..f6b5937621154 100644
--- a/clang/lib/AST/MicrosoftMangle.cpp
+++ b/clang/lib/AST/MicrosoftMangle.cpp
@@ -3739,6 +3739,11 @@ void MicrosoftCXXNameMangler::mangleType(const HLSLAttributedResourceType *T,
llvm_unreachable("HLSL uses Itanium name mangling");
}
+// Not expected to be called: marked unreachable because HLSL uses Itanium
+// name mangling.
+void MicrosoftCXXNameMangler::mangleType(const HLSLInlineSpirvType *T,
+ Qualifiers, SourceRange Range) {
+ llvm_unreachable("HLSL uses Itanium name mangling");
+}
+
// <this-adjustment> ::= <no-adjustment> | <static-adjustment> |
// <virtual-adjustment>
// <no-adjustment> ::= A # private near
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 667ffc0e599a6..f9a6ccdb7bc6b 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -4654,6 +4654,8 @@ static CachedProperties computeCachedProperties(const Type *T) {
return Cache::get(cast<PipeType>(T)->getElementType());
case Type::HLSLAttributedResource:
return Cache::get(cast<HLSLAttributedResourceType>(T)->getWrappedType());
+ case Type::HLSLInlineSpirv:
+ return CachedProperties(Linkage::External, false);
}
llvm_unreachable("unhandled type class");
@@ -4748,6 +4750,17 @@ LinkageInfo LinkageComputer::computeTypeLinkageInfo(const Type *T) {
return computeTypeLinkageInfo(cast<HLSLAttributedResourceType>(T)
->getContainedType()
->getCanonicalTypeInternal());
+  case Type::HLSLInlineSpirv:
+    // Inline SPIR-V types always report external linkage; operand types are
+    // not consulted.
+    // NOTE(review): an unreachable block that merged the linkage of each
+    // constant/type operand's result type used to follow this return. If
+    // operand linkage should matter, reinstate that logic deliberately in
+    // place of the unconditional return.
+    return LinkageInfo::external();
}
llvm_unreachable("unhandled type class");
@@ -4938,6 +4951,7 @@ bool Type::canHaveNullability(bool ResultIfUnknown) const {
case Type::DependentBitInt:
case Type::ArrayParameter:
case Type::HLSLAttributedResource:
+ case Type::HLSLInlineSpirv:
return false;
}
llvm_unreachable("bad type kind!");
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index 4ec252e3f89b5..01fd90c40e4a5 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -247,6 +247,7 @@ bool TypePrinter::canPrefixQualifiers(const Type *T,
case Type::DependentBitInt:
case Type::BTFTagAttributed:
case Type::HLSLAttributedResource:
+ case Type::HLSLInlineSpirv:
CanPrefixQualifiers = true;
break;
@@ -2135,6 +2136,53 @@ void TypePrinter::printHLSLAttributedResourceAfter(
}
}
+// Prints an inline SPIR-V type using the builtin template spelling
+// `__hlsl_spirv_type<Opcode, Size, Alignment, Operands...>`. The complete
+// spelling, including the closing '>', is written here.
+void TypePrinter::printHLSLInlineSpirvBefore(const HLSLInlineSpirvType *T,
+                                             raw_ostream &OS) {
+  using SpirvOperandKind = SpirvOperand::SpirvOperandKind;
+
+  // Fixed leading arguments: opcode, size, alignment.
+  OS << "__hlsl_spirv_type<" << T->getOpcode();
+  OS << ", " << T->getSize() << ", " << T->getAlignment();
+
+  for (auto &Operand : T->getOperands()) {
+    OS << ", ";
+    switch (Operand.getKind()) {
+    case SpirvOperandKind::kConstantId: {
+      // Constants are spelled vk::integral_constant<Type, Value>.
+      QualType ValueTy = Operand.getResultType();
+      OS << "vk::integral_constant<";
+      printBefore(ValueTy, OS);
+      printAfter(ValueTy, OS);
+      OS << ", " << Operand.getValue() << ">";
+      break;
+    }
+    case SpirvOperandKind::kLiteral:
+      // Literals are always spelled with `uint` as the constant type.
+      OS << "vk::Literal<vk::integral_constant<uint, " << Operand.getValue()
+         << ">>";
+      break;
+    case SpirvOperandKind::kTypeId: {
+      // Type operands are printed as the type itself.
+      QualType OperandTy = Operand.getResultType();
+      printBefore(OperandTy, OS);
+      printAfter(OperandTy, OS);
+      break;
+    }
+    default:
+      llvm_unreachable("Invalid SpirvOperand kind!");
+    }
+  }
+
+  OS << ">";
+}
+
+// Counterpart of printHLSLInlineSpirvBefore; it writes nothing to OS.
+void TypePrinter::printHLSLInlineSpirvAfter(const HLSLInlineSpirvType *T,
+                                            raw_ostream &OS) {}
+
void TypePrinter::printObjCInterfaceBefore(const ObjCInterfaceType *T,
raw_ostream &OS) {
OS << T->getDecl()->getName();
diff --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp
index 52aa956121d73..179d93b2c3672 100644
--- a/clang/lib/CodeGen/CGDebugInfo.cpp
+++ b/clang/lib/CodeGen/CGDebugInfo.cpp
@@ -3521,6 +3521,12 @@ llvm::DIType *CGDebugInfo::CreateType(const HLSLAttributedResourceType *Ty,
return getOrCreateType(Ty->getWrappedType(), U);
}
+// No debug information is produced for inline SPIR-V types: the result is
+// always a null DIType.
+llvm::DIType *CGDebugInfo::CreateType(const HLSLInlineSpirvType *Ty,
+                                      llvm::DIFile *U) {
+  llvm::DIType *NoDebugInfo = nullptr;
+  return NoDebugInfo;
+}
+
llvm::DIType *CGDebugInfo::CreateEnumType(const EnumType *Ty) {
const EnumDecl *ED = Ty->getDecl();
@@ -3874,6 +3880,8 @@ llvm::DIType *CGDebugInfo::CreateTypeNode(QualType Ty, llvm::DIFile *Unit) {
return CreateType(cast<TemplateSpecializationType>(Ty), Unit);
case Type::HLSLAttributedResource:
return CreateType(cast<HLSLAttributedResourceType>(Ty), Unit);
+ case Type::HLSLInlineSpirv:
+ return CreateType(cast<HLSLInlineSpirvType>(Ty), Unit);
case Type::CountAttributed:
case Type::Auto:
diff --git a/clang/lib/CodeGen/CGDebugInfo.h b/clang/lib/CodeGen/CGDebugInfo.h
index b287ce7b92eee..7a63fa4b00278 100644
--- a/clang/lib/CodeGen/CGDebugInfo.h
+++ b/clang/lib/CodeGen/CGDebugInfo.h
@@ -198,6 +198,7 @@ class CGDebugInfo {
llvm::DIType *CreateType(const FunctionType *Ty, llvm::DIFile *F);
llvm::DIType *CreateType(const HLSLAttributedResourceType *Ty,
llvm::DIFile *F);
+ llvm::DIType *CreateType(const HLSLInlineSpirvType *Ty, llvm::DIFile *F);
/// Get structure or union type.
llvm::DIType *CreateType(const RecordType *Tyg);
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index dcf523f56bf1e..0fcf434e1d033 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -283,6 +283,7 @@ TypeEvaluationKind CodeGenFunction::getEvaluationKind(QualType type) {
case Type::Pipe:
case Type::BitInt:
case Type::HLSLAttributedResource:
+ case Type::HLSLInlineSpirv:
return TEK_Scalar;
// Complexes.
@@ -2452,6 +2453,7 @@ void CodeGenFunction::EmitVariablyModifiedType(QualType type) {
case Type::ObjCInterface:
case Type::ObjCObjectPointer:
case Type::BitInt:
+ case Type::HLSLInlineSpirv:
llvm_unreachable("type class is never variably-modified!");
case Type::Elaborated:
diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp
index 11cf5758b6d3a..3a22888647445 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -767,6 +767,7 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
break;
}
case Type::HLSLAttributedResource:
+ case Type::HLSLInlineSpirv:
ResultType = CGM.getHLSLRuntime().convertHLSLSpecificType(Ty);
break;
}
@@ -877,6 +878,11 @@ bool CodeGenTypes::isZeroInitializable(QualType T) {
if (const MemberPointerType *MPT = T->getAs<MemberPointerType>())
return getCXXABI().isZeroInitializable(MPT);
+ // HLSL Inline SPIR-V types are non-zero-initializable.
+ if (T->getAs<HLSLInlineSpirvType>()) {
+ return false;
+ }
+
// Everything else is okay.
return true;
}
diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp
index 38e3a63ebfb11..71ce242b10b99 100644
--- a/clang/lib/CodeGen/ItaniumCXXABI.cpp
+++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp
@@ -3959,6 +3959,7 @@ void ItaniumRTTIBuilder::BuildVTablePointer(const Type *Ty,
break;
case Type::HLSLAttributedResource:
+ case Type::HLSLInlineSpirv:
llvm_unreachable("HLSL doesn't support virtual functions");
}
@@ -4237,6 +4238,7 @@ llvm::Constant *ItaniumRTTIBuilder::BuildTypeInfo(
break;
case Type::HLSLAttributedResource:
+ case Type::HLSLInlineSpirv:
llvm_unreachable("HLSL doesn't support RTTI");
}
diff --git a/clang/lib/CodeGen/Targets/SPIR.cpp b/clang/lib/CodeGen/Targets/SPIR.cpp
index 225d9dfbd980b..1c1f243dc84c7 100644
--- a/clang/lib/CodeGen/Targets/SPIR.cpp
+++ b/clang/lib/CodeGen/Targets/SPIR.cpp
@@ -369,14 +369,102 @@ llvm::Type *CommonSPIRTargetCodeGenInfo::getOpenCLType(CodeGenModule &CGM,
return nullptr;
}
+// Gets a spirv.IntegralConstant or spirv.Literal. If IntegralType is present,
+// returns an IntegralConstant, otherwise returns a Literal.
+//
+// Value is encoded as a sequence of 32-bit words, least-significant word
+// first, and is treated as unsigned throughout. A zero value is encoded as
+// the single word 0.
+static llvm::Type *getInlineSpirvConstant(CodeGenModule &CGM,
+                                          llvm::Type *IntegralType,
+                                          llvm::APInt Value) {
+  llvm::LLVMContext &Ctx = CGM.getLLVMContext();
+
+  // Constants narrower than one word (e.g. bool or 16-bit integers) must be
+  // widened before splitting: APInt::trunc(32) and APInt::lshrInPlace(32)
+  // both assert when the bit width is below 32.
+  if (Value.getBitWidth() < 32)
+    Value = Value.zext(32);
+
+  // Peel off the low 32 bits until no set bits remain.
+  llvm::SmallVector<uint32_t> Words;
+  while (Value.ugt(0)) {
+    Words.push_back(static_cast<uint32_t>(Value.trunc(32).getZExtValue()));
+    Value.lshrInPlace(32);
+  }
+  if (Words.empty())
+    Words.push_back(0);
+
+  if (!IntegralType)
+    return llvm::TargetExtType::get(Ctx, "spirv.Literal", {}, Words);
+  return llvm::TargetExtType::get(Ctx, "spirv.IntegralConstant",
+                                  {IntegralType}, Words);
+}
+
+// Lowers an HLSLInlineSpirvType to the `spirv.Type` target extension type.
+// Each operand becomes one type parameter, in order; opcode, size and
+// alignment become the integer parameters.
+static llvm::Type *getInlineSpirvType(CodeGenModule &CGM,
+                                      const HLSLInlineSpirvType *SpirvType) {
+  using SpirvOperandKind = SpirvOperand::SpirvOperandKind;
+
+  llvm::SmallVector<llvm::Type *> OperandTypes;
+  for (auto &Operand : SpirvType->getOperands()) {
+    llvm::Type *Lowered = nullptr;
+
+    switch (Operand.getKind()) {
+    case SpirvOperandKind::kConstantId: {
+      // Typed constant: the converted result type selects the
+      // spirv.IntegralConstant form.
+      llvm::Type *ValueTy = CGM.getTypes().ConvertType(Operand.getResultType());
+      Lowered = getInlineSpirvConstant(CGM, ValueTy, Operand.getValue());
+      break;
+    }
+    case SpirvOperandKind::kLiteral:
+      // Untyped literal: a null integral type selects the spirv.Literal form.
+      Lowered = getInlineSpirvConstant(CGM, nullptr, Operand.getValue());
+      break;
+    case SpirvOperandKind::kTypeId: {
+      QualType OperandTy = Operand.getResultType();
+      // A record whose first named data member has an HLSL attributed
+      // resource type is lowered as that member's type instead of the record.
+      if (auto *RT = OperandTy->getAs<RecordType>()) {
+        auto *RD = RT->getDecl();
+        assert(RD->isCompleteDefinition() &&
+               "Type completion should have been required in Sema");
+
+        if (const FieldDecl *Handle = RD->findFirstNamedDataMember()) {
+          QualType HandleTy = Handle->getType();
+          if (HandleTy->getAs<HLSLAttributedResourceType>())
+            OperandTy = HandleTy;
+        }
+      }
+      Lowered = CGM.getTypes().ConvertType(OperandTy);
+      break;
+    }
+    default:
+      llvm_unreachable("HLSLInlineSpirvType had invalid operand!");
+    }
+
+    assert(Lowered);
+    OperandTypes.push_back(Lowered);
+  }
+
+  llvm::LLVMContext &Ctx = CGM.getLLVMContext();
+  return llvm::TargetExtType::get(Ctx, "spirv.Type", OperandTypes,
+                                  {SpirvType->getOpcode(), SpirvType->getSize(),
+                                   SpirvType->getAlignment()});
+}
+
llvm::Type *CommonSPIRTargetCodeGenInfo::getHLSLType(
CodeGenModule &CGM, const Type *Ty,
const SmallVector<int32_t> *Packoffsets) const {
+ llvm::LLVMContext &Ctx = CGM.getLLVMContext();
+
+ if (auto *SpirvType = dyn_cast<HLSLInlineSpirvType>(Ty)) {
+ return getInlineSpirvType(CGM, SpirvType);
+ }
+
auto *ResType = dyn_cast<HLSLAttributedResourceType>(Ty);
if (!ResType)
return nullptr;
- llvm::LLVMContext &Ctx = CGM.getLLVMContext();
const HLSLAttributedResourceType::Attributes &ResAttrs = ResType->getAttrs();
switch (ResAttrs.ResourceClass) {
case llvm::dxil::ResourceClass::UAV:
diff --git a/clang/lib/Headers/CMakeLists.txt b/clang/lib/Headers/CMakeLists.txt
index acf49e40c447e..e9555cf7e0898 100644
--- a/clang/lib/Headers/CMakeLists.txt
+++ b/clang/lib/Headers/CMakeLists.txt
@@ -91,6 +91,7 @@ set(hlsl_subdir_files
hlsl/hlsl_intrinsic_helpers.h
hlsl/hlsl_intrinsics.h
hlsl/hlsl_detail.h
+ hlsl/hlsl_spirv.h
)
set(hlsl_files
${hlsl_h}
diff --git a/clang/lib/Headers/hlsl.h b/clang/lib/Headers/hlsl.h
index b494b4d0f78bb..684d29d5ed55b 100644
--- a/clang/lib/Headers/hlsl.h
+++ b/clang/lib/Headers/hlsl.h
@@ -27,6 +27,10 @@
#endif
#include "hlsl/hlsl_intrinsics.h"
+#ifdef __spirv__
+#include "hlsl/hlsl_spirv.h"
+#endif // __spirv__
+
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
diff --git a/clang/lib/Headers/hlsl/hlsl_spirv.h b/clang/lib/Headers/hlsl/hlsl_spirv.h
new file mode 100644
index 0000000000000..8a71699a4ed5c
--- /dev/null
+++ b/clang/lib/Headers/hlsl/hlsl_spirv.h
@@ -0,0 +1,30 @@
+//===----- hlsl_spirv.h - HLSL definitions for SPIR-V target --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _HLSL_HLSL_SPIRV_H_
+#define _HLSL_HLSL_SPIRV_H_
+
+namespace hlsl {
+namespace vk {
+/// Compile-time constant of type T with value v. Intended for use as an
+/// operand of SpirvType / SpirvOpaqueType.
+template <typename T, T v> struct integral_constant {
+  static constexpr T value = v;
+};
+
+/// Marker wrapper for an operand of SpirvType / SpirvOpaqueType. T is
+/// expected to be an integral_constant.
+template <typename T> struct Literal {};
+
+/// Inline SPIR-V type built from Opcode with an explicit Size and Alignment,
+/// followed by the operand list.
+template <uint Opcode, uint Size, uint Alignment, typename... Operands>
+using SpirvType = __hlsl_spirv_type<Opcode, Size, Alignment, Operands...>;
+
+/// Inline SPIR-V type whose Size and Alignment are both passed as 0.
+template <uint Opcode, typename... Operands>
+using SpirvOpaqueType = __hlsl_spirv_type<Opcode, 0, 0, Operands...>;
+} // namespace vk
+} // namespace hlsl
+
+#endif // _HLSL_HLSL_SPIRV_H_
\ No newline at end of file
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index e7f418ae6802e..b27c884bb0d8c 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -4478,6 +4478,7 @@ static void captureVariablyModifiedType(ASTContext &Context, QualType T,
case Type::ObjCTypeParam:
case Type::Pipe:
case Type::BitInt:
+ case Type::HLSLInlineSpirv:
llvm_unreachable("type class is never variably-modified!");
case Type::Elaborated:
T = cast<ElaboratedType>(Ty)->getNamedType();
diff --git a/clang/lib/Sema/SemaLookup.cpp b/clang/lib/Sema/SemaLookup.cpp
index a77ca779a9ee3..a1f5d673c1b5a 100644
--- a/clang/lib/Sema/SemaLookup.cpp
+++ b/clang/lib/Sema/SemaLookup.cpp
@@ -923,13 +923,25 @@ bool Sema::LookupBuiltin(LookupResult &R) {
NameKind == Sema::LookupRedeclarationWithLinkage) {
IdentifierInfo *II = R.getLookupName().getAsIdentifierInfo();
if (II) {
- if (getLangOpts().CPlusPlus && NameKind == Sema::LookupOrdinaryName) {
-#define BuiltinTemplate(BIName) \
+ if (NameKind == Sema::LookupOrdinaryName) {
+ if (getLangOpts().CPlusPlus) {
+#define BuiltinTemplate(BIName)
+#define CPlusPlusBuiltinTemplate(BIName) \
if (II == getASTContext().get##BIName##Name()) { \
R.addDecl(getASTContext().get##BIName##Decl()); \
return true; \
}
#include "clang/Basic/BuiltinTemplates.inc"
+ }
+ if (getLangOpts().HLSL) {
+#define BuiltinTemplate(BIName)
+#define HLSLBuiltinTemplate(BIName) \
+ if (II == getASTContext().get##BIName##Name()) { \
+ R.addDecl(getASTContext().get##BIName##Decl()); \
+ return true; \
+ }
+#include "clang/Basic/BuiltinTemplates.inc"
+ }
}
// Check if this is an OpenCL Builtin, and if so, insert its overloads.
@@ -3265,6 +3277,11 @@ addAssociatedClassesAndNamespaces(AssociatedLookup &Result, QualType Ty) {
case Type::HLSLAttributedResource:
T = cast<HLSLAttributedResourceType>(T)->getWrappedType().getTypePtr();
+ break;
+
+ // Inline SPIR-V types are treated as fundamental types.
+ case Type::HLSLInlineSpirv:
+ break;
}
if (Queue.empty())
diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index 1f87ef4b27bab..baab78327dc6d 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -3228,6 +3228,62 @@ static QualType builtinCommonTypeImpl(Sema &S, TemplateName BaseTemplate,
}
}
+// Returns true when the record named by RT is declared directly inside a
+// namespace whose qualified name is exactly "hlsl::vk".
+static bool isInVkNamespace(const RecordType *RT) {
+  if (DeclContext *Parent = RT->getDecl()->getDeclContext())
+    if (NamespaceDecl *NS = dyn_cast<NamespaceDecl>(Parent))
+      return NS->getQualifiedNameAsString() == "hlsl::vk";
+  return false;
+}
+
+// Classifies one type argument of __hlsl_spirv_type as a SpirvOperand:
+//  - hlsl::vk::integral_constant<T, V>           -> constant operand (T, V)
+//  - hlsl::vk::Literal<integral_constant<T, V>>  -> literal operand (V)
+//  - anything else                               -> type operand, which must
+//                                                   be a complete type
+// A default-constructed SpirvOperand is returned after a diagnostic when
+// Literal wraps something other than integral_constant, or when
+// RequireCompleteType reports a failure at Loc.
+static SpirvOperand checkHLSLSpirvTypeOperand(Sema &SemaRef,
+                                              QualType OperandArg,
+                                              SourceLocation Loc) {
+  const auto *RT = OperandArg->getAs<RecordType>();
+
+  // Peel an optional vk::Literal<...> wrapper and continue with its argument.
+  bool IsLiteral = false;
+  SourceLocation LiteralLoc;
+  if (RT && isInVkNamespace(RT) && RT->getDecl()->getName() == "Literal") {
+    auto Wrapper = dyn_cast<ClassTemplateSpecializationDecl>(RT->getDecl());
+    assert(Wrapper);
+
+    IsLiteral = true;
+    LiteralLoc = Wrapper->getSourceRange().getBegin();
+    QualType Wrapped = Wrapper->getTemplateArgs()[0].getAsType();
+    RT = Wrapped->getAs<RecordType>();
+  }
+
+  if (RT && isInVkNamespace(RT) &&
+      RT->getDecl()->getName() == "integral_constant") {
+    auto Constant = dyn_cast<ClassTemplateSpecializationDecl>(RT->getDecl());
+    assert(Constant);
+
+    const TemplateArgumentList &Args = Constant->getTemplateArgs();
+    QualType ValueTy = Args[0].getAsType();
+    llvm::APInt Value = Args[1].getAsIntegral();
+
+    if (IsLiteral)
+      return SpirvOperand::createLiteral(Value);
+    return SpirvOperand::createConstant(ValueTy, Value);
+  }
+
+  // A Literal wrapper around anything but integral_constant is an error.
+  if (IsLiteral) {
+    SemaRef.Diag(LiteralLoc, diag::err_hlsl_vk_literal_must_contain_constant);
+    return SpirvOperand();
+  }
+
+  // Everything else is a plain type operand and has to be complete.
+  if (SemaRef.RequireCompleteType(Loc, OperandArg,
+                                  diag::err_call_incomplete_argument))
+    return SpirvOperand();
+  return SpirvOperand::createType(OperandArg);
+}
+
static QualType
checkBuiltinTemplateIdType(Sema &SemaRef, BuiltinTemplateDecl *BTD,
ArrayRef<TemplateArgument> Converted,
@@ -3289,7 +3345,7 @@ checkBuiltinTemplateIdType(Sema &SemaRef, BuiltinTemplateDecl *BTD,
// __type_pack_element<Index, T_1, ..., T_N>
// are treated like T_Index.
assert(Converted.size() == 2 &&
- "__type_pack_element should be given an index and a parameter pack");
+ "__type_pack_element should be given an index and a parameter pack");
TemplateArgument IndexArg = Converted[0], Ts = Converted[1];
if (IndexArg.isDependent() || Ts.isDependent())
@@ -3332,6 +3388,39 @@ checkBuiltinTemplateIdType(Sema &SemaRef, BuiltinTemplateDecl *BTD,
}
return HasNoTypeMember;
}
+
+ case BTK__hlsl_spirv_type: {
+ assert(Converted.size() == 4);
+
+ if (!Context.getTargetInfo().getTriple().isSPIRV()) {
+ SemaRef.Diag(TemplateLoc, diag::err_hlsl_spirv_only)
+ << "__hlsl_spirv_type";
+ }
+
+ if (llvm::any_of(Converted, [](auto &C) { return C.isDependent(); }))
+ return Context.getCanonicalTemplateSpecializationType(TemplateName(BTD),
+ Converted);
+
+ uint64_t Opcode = Converted[0].getAsIntegral().getZExtValue();
+ uint64_t Size = Converted[1].getAsIntegral().getZExtValue();
+ uint64_t Alignment = Converted[2].getAsIntegral().getZExtValue();
+
+ ArrayRef<TemplateArgument> OperandArgs = Converted[3].getPackAsArray();
+
+ llvm::SmallVector<SpirvOperand> Operands;
+
+ for (auto &OperandTA : OperandArgs) {
+ QualType OperandArg = OperandTA.getAsType();
+ auto Operand = checkHLSLSpirvTypeOperand(SemaRef, OperandArg,
+ TemplateArgs[3].getLocation());
+ if (!Operand.isValid()) {
+ return QualType();
+ }
+ Operands.push_back(Operand);
+ }
+
+ return Context.getHLSLInlineSpirvType(Opcode, Size, Alignment, Operands);
+ }
}
llvm_unreachable("unexpected BuiltinTemplateDecl!");
}
@@ -6165,6 +6254,18 @@ bool UnnamedLocalNoLinkageFinder::VisitHLSLAttributedResourceType(
return Visit(T->getWrappedType());
}
+// Visits the result type of every operand that has one and returns true as
+// soon as any of those visits does.
+bool UnnamedLocalNoLinkageFinder::VisitHLSLInlineSpirvType(
+    const HLSLInlineSpirvType *T) {
+  for (auto &Operand : T->getOperands()) {
+    // Literal operands are created without a result type; only constant and
+    // type operands have one to inspect. (An operand cannot be both a
+    // constant and a literal, so testing for both would skip every operand.)
+    if (!Operand.isConstant() && !Operand.isType())
+      continue;
+    if (Visit(Operand.getResultType()))
+      return true;
+  }
+  return false;
+}
+
bool Sema::CheckTemplateArgument(TypeSourceInfo *ArgInfo) {
assert(ArgInfo && "invalid TypeSourceInfo");
QualType Arg = ArgInfo->getType();
diff --git a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp
index 9969f1762fe36..147e078539a5d 100644
--- a/clang/lib/Sema/SemaTemplateDeduction.cpp
+++ b/clang/lib/Sema/SemaTemplateDeduction.cpp
@@ -2492,6 +2492,7 @@ static TemplateDeductionResult DeduceTemplateArgumentsByTypeMatch(
case Type::Pipe:
case Type::ArrayParameter:
case Type::HLSLAttributedResource:
+ case Type::HLSLInlineSpirv:
// No template argument deduction for these types
return TemplateDeductionResult::Success;
@@ -7116,6 +7117,7 @@ MarkUsedTemplateParameters(ASTContext &Ctx, QualType T,
case Type::UnresolvedUsing:
case Type::Pipe:
case Type::BitInt:
+ case Type::HLSLInlineSpirv:
#define TYPE(Class, Base)
#define ABSTRACT_TYPE(Class, Base)
#define DEPENDENT_TYPE(Class, Base)
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 2df961a48c7c3..e3070675fa2c1 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -5874,6 +5874,7 @@ namespace {
Visit(TL.getWrappedLoc());
fillHLSLAttributedResourceTypeLoc(TL, State);
}
+ void VisitHLSLInlineSpirvTypeLoc(HLSLInlineSpirvTypeLoc TL) {}
void VisitMacroQualifiedTypeLoc(MacroQualifiedTypeLoc TL) {
Visit(TL.getInnerLoc());
TL.setExpansionLoc(
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 916b8e2735cd0..f3bbaf78ceddf 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -7652,6 +7652,13 @@ QualType TreeTransform<Derived>::TransformHLSLAttributedResourceType(
return Result;
}
+// Inline SPIR-V types are handed back as-is: the result is the type already
+// recorded in TL.
+template <typename Derived>
+QualType TreeTransform<Derived>::TransformHLSLInlineSpirvType(
+    TypeLocBuilder &TLB, HLSLInlineSpirvTypeLoc TL) {
+  QualType Unchanged = TL.getType();
+  return Unchanged;
+}
+
template<typename Derived>
QualType
TreeTransform<Derived>::TransformParenType(TypeLocBuilder &TLB,
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 58a57d6c54523..7b2bb7b00fbff 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -7284,6 +7284,10 @@ void TypeLocReader::VisitHLSLAttributedResourceTypeLoc(
// Nothing to do.
}
+// No data is read for an HLSLInlineSpirvTypeLoc.
+void TypeLocReader::VisitHLSLInlineSpirvTypeLoc(HLSLInlineSpirvTypeLoc TL) {}
+
void TypeLocReader::VisitTemplateTypeParmTypeLoc(TemplateTypeParmTypeLoc TL) {
TL.setNameLoc(readSourceLocation());
}
@@ -9753,6 +9757,11 @@ TypeCoupledDeclRefInfo ASTRecordReader::readTypeCoupledDeclRefInfo() {
return TypeCoupledDeclRefInfo(readDeclAs<ValueDecl>(), readBool());
}
+// Reads one SpirvOperand as three consecutive fields: kind, result type and
+// value, in that order.
+//
+// Each field is read in its own statement on purpose. The order in which
+// function-call arguments are evaluated is unspecified in C++, so performing
+// the three stateful reads inside the constructor's argument list would let
+// the host compiler consume the record in a different order.
+SpirvOperand ASTRecordReader::readHLSLSpirvOperand() {
+  auto Kind = SpirvOperand::SpirvOperandKind(readInt());
+  QualType ResultType = readQualType();
+  llvm::APInt Value = readAPInt();
+  return SpirvOperand(Kind, ResultType, Value);
+}
+
void ASTRecordReader::readQualifierInfo(QualifierInfo &Info) {
Info.QualifierLoc = readNestedNameSpecifierLoc();
unsigned NumTPLists = readInt();
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 84f7f2bc5fce4..ac6647632c4ea 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -604,6 +604,10 @@ void TypeLocWriter::VisitHLSLAttributedResourceTypeLoc(
// Nothing to do.
}
+// No data is written for an HLSLInlineSpirvTypeLoc.
+void TypeLocWriter::VisitHLSLInlineSpirvTypeLoc(HLSLInlineSpirvTypeLoc TL) {}
+
void TypeLocWriter::VisitTemplateTypeParmTypeLoc(TemplateTypeParmTypeLoc TL) {
addSourceLocation(TL.getNameLoc());
}
diff --git a/clang/test/AST/HLSL/Inputs/pch_spirv_type.hlsl b/clang/test/AST/HLSL/Inputs/pch_spirv_type.hlsl
new file mode 100644
index 0000000000000..326c75dbc3bbe
--- /dev/null
+++ b/clang/test/AST/HLSL/Inputs/pch_spirv_type.hlsl
@@ -0,0 +1,6 @@
+
+float2 foo(float2 a, float2 b) {
+ return a + b;
+}
+
+vk::SpirvOpaqueType</* OpTypeArray */ 28, RWBuffer<float>, vk::integral_constant<uint, 4>> buffers;
diff --git a/clang/test/AST/HLSL/ast-dump-SpirvType.hlsl b/clang/test/AST/HLSL/ast-dump-SpirvType.hlsl
new file mode 100644
index 0000000000000..f9aaf368ac935
--- /dev/null
+++ b/clang/test/AST/HLSL/ast-dump-SpirvType.hlsl
@@ -0,0 +1,27 @@
+// RUN: %clang_cc1 -finclude-default-header -triple spirv-unknown-vulkan-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+
+// CHECK: TypedefDecl 0x{{.+}} <{{.+}}:4:1, col:83> col:83 referenced AType 'vk::SpirvOpaqueType<123, RWBuffer<float>, vk::integral_constant<uint, 4>>':'__hlsl_spirv_type<123, 0, 0, RWBuffer<float>, vk::integral_constant<unsigned int, 4>>'
+typedef vk::SpirvOpaqueType<123, RWBuffer<float>, vk::integral_constant<uint, 4>> AType;
+// CHECK: TypedefDecl 0x{{.+}} <{{.+}}:6:1, col:133> col:133 referenced BType 'vk::SpirvType<12, 2, 4, vk::integral_constant<uint64_t, 4886718345L>, float, vk::Literal<vk::integral_constant<uint, 456>>>':'__hlsl_spirv_type<12, 2, 4, vk::integral_constant<unsigned long, 4886718345>, float, vk::Literal<vk::integral_constant<uint, 456>>>'
+typedef vk::SpirvType<12, 2, 4, vk::integral_constant<uint64_t, 0x123456789>, float, vk::Literal<vk::integral_constant<uint, 456>>> BType;
+
+// CHECK: VarDecl 0x{{.+}} <{{.+}}:9:1, col:7> col:7 AValue 'hlsl_constant AType':'hlsl_constant __hlsl_spirv_type<123, 0, 0, RWBuffer<float>, vk::integral_constant<unsigned int, 4>>'
+AType AValue;
+// CHECK: VarDecl 0x{{.+}} <{{.+}}:11:1, col:7> col:7 BValue 'hlsl_constant BType':'hlsl_constant __hlsl_spirv_type<12, 2, 4, vk::integral_constant<unsigned long, 4886718345>, float, vk::Literal<vk::integral_constant<uint, 456>>>'
+BType BValue;
+
+// CHECK: VarDecl 0x{{.+}} <{{.+}}:14:1, col:80> col:80 CValue 'hlsl_constant vk::SpirvOpaqueType<123, vk::Literal<vk::integral_constant<uint, 305419896>>>':'hlsl_constant __hlsl_spirv_type<123, 0, 0, vk::Literal<vk::integral_constant<uint, 305419896>>>'
+vk::SpirvOpaqueType<123, vk::Literal<vk::integral_constant<uint, 0x12345678>>> CValue;
+
+// CHECK: TypeAliasDecl 0x{{.+}} <{{.+}}:18:1, col:72> col:7 Array 'vk::SpirvOpaqueType<28, T, vk::integral_constant<uint, L>>':'__hlsl_spirv_type<28, 0, 0, type-parameter-0-0, integral_constant<unsigned int, L>>'
+template <class T, uint L>
+using Array = vk::SpirvOpaqueType<28, T, vk::integral_constant<uint, L>>;
+
+// CHECK: VarDecl 0x{{.+}} <{{.+}}:21:1, col:16> col:16 DValue 'hlsl_constant Array<uint, 5>':'hlsl_constant __hlsl_spirv_type<28, 0, 0, uint, vk::integral_constant<unsigned int, 5>>'
+Array<uint, 5> DValue;
+
+[numthreads(1, 1, 1)]
+void main() {
+// CHECK: VarDecl 0x{{.+}} <col:5, col:11> col:11 EValue 'AType':'__hlsl_spirv_type<123, 0, 0, RWBuffer<float>, vk::integral_constant<unsigned int, 4>>'
+ AType EValue;
+}
diff --git a/clang/test/AST/HLSL/pch_spirv_type.hlsl b/clang/test/AST/HLSL/pch_spirv_type.hlsl
new file mode 100644
index 0000000000000..045f89a1b8461
--- /dev/null
+++ b/clang/test/AST/HLSL/pch_spirv_type.hlsl
@@ -0,0 +1,17 @@
+// RUN: %clang_cc1 -triple spirv-unknown-vulkan-library -x hlsl \
+// RUN: -finclude-default-header -emit-pch -o %t %S/Inputs/pch_spirv_type.hlsl
+// RUN: %clang_cc1 -triple spirv-unknown-vulkan-library -x hlsl \
+// RUN: -finclude-default-header -include-pch %t -ast-dump-all %s \
+// RUN: | FileCheck %s
+
+// Make sure PCH works by using function declared in PCH header and declare a SpirvType in current file.
+// CHECK:FunctionDecl 0x[[FOO:[0-9a-f]+]] <{{.*}}:2:1, line:4:1> line:2:8 imported used foo 'float2 (float2, float2)'
+// CHECK:VarDecl 0x{{[0-9a-f]+}} <{{.*}}:10:1, col:92> col:92 buffers2 'hlsl_constant vk::SpirvOpaqueType<28, RWBuffer<float>, vk::integral_constant<uint, 4>>':'hlsl_constant __hlsl_spirv_type<28, 0, 0, RWBuffer<float>, vk::integral_constant<unsigned int, 4>>'
+vk::SpirvOpaqueType</* OpTypeArray */ 28, RWBuffer<float>, vk::integral_constant<uint, 4>> buffers2;
+
+float2 bar(float2 a, float2 b) {
+// CHECK:CallExpr 0x{{[0-9a-f]+}} <col:10, col:18> 'float2':'vector<float, 2>'
+// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <col:10> 'float2 (*)(float2, float2)' <FunctionToPointerDecay>
+// CHECK-NEXT:`-DeclRefExpr 0x{{[0-9a-f]+}} <col:10> 'float2 (float2, float2)' lvalue Function 0x[[FOO]] 'foo' 'float2 (float2, float2)'
+ return foo(a, b);
+}
diff --git a/clang/test/AST/HLSL/vector-alias.hlsl b/clang/test/AST/HLSL/vector-alias.hlsl
index 58d80e9b4a4e4..e1f78e6abdca8 100644
--- a/clang/test/AST/HLSL/vector-alias.hlsl
+++ b/clang/test/AST/HLSL/vector-alias.hlsl
@@ -1,53 +1,52 @@
-// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
-
-// CHECK: NamespaceDecl {{.*}} implicit hlsl
-// CHECK-NEXT: TypeAliasTemplateDecl {{.*}} implicit vector
-// CHECK-NEXT: TemplateTypeParmDecl {{.*}} class depth 0 index 0 element
-// CHECK-NEXT: TemplateArgument type 'float'
-// CHECK-NEXT: BuiltinType {{.*}} 'float'
-// CHECK-NEXT: NonTypeTemplateParmDecl {{.*}} 'int' depth 0 index 1 element_count
-// CHECK-NEXT: TemplateArgument expr
-// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 4
-// CHECK-NEXT: TypeAliasDecl {{.*}} implicit vector 'vector<element, element_count>'
-// CHECK-NEXT: DependentSizedExtVectorType {{.*}} 'vector<element, element_count>' dependent
-// CHECK-NEXT: TemplateTypeParmType {{.*}} 'element' dependent depth 0 index 0
-// CHECK-NEXT: TemplateTypeParm {{.*}} 'element'
-// CHECK-NEXT: DeclRefExpr {{.*}} 'int' lvalue
-// NonTypeTemplateParm {{.*}} 'element_count' 'int'
-
-// Make sure we got a using directive at the end.
-// CHECK: UsingDirectiveDecl {{.*}} Namespace {{.*}} 'hlsl'
-
-[numthreads(1,1,1)]
-int entry() {
- // Verify that the alias is generated inside the hlsl namespace.
- hlsl::vector<float, 2> Vec2 = {1.0, 2.0};
-
- // CHECK: DeclStmt
- // CHECK-NEXT: VarDecl {{.*}} Vec2 'hlsl::vector<float, 2>':'vector<float, 2>' cinit
-
- // Verify that you don't need to specify the namespace.
- vector<int, 2> Vec2a = {1, 2};
-
- // CHECK: DeclStmt
- // CHECK-NEXT: VarDecl {{.*}} Vec2a 'vector<int, 2>' cinit
-
- // Build a bigger vector.
- vector<double, 4> Vec4 = {1.0, 2.0, 3.0, 4.0};
-
- // CHECK: DeclStmt
- // CHECK-NEXT: VarDecl {{.*}} used Vec4 'vector<double, 4>' cinit
-
- // Verify that swizzles still work.
- vector<double, 3> Vec3 = Vec4.xyz;
-
- // CHECK: DeclStmt {{.*}}
- // CHECK-NEXT: VarDecl {{.*}} Vec3 'vector<double, 3>' cinit
-
- // Verify that the implicit arguments generate the correct type.
- vector<> ImpVec4 = {1.0, 2.0, 3.0, 4.0};
-
- // CHECK: DeclStmt
- // CHECK-NEXT: VarDecl {{.*}} ImpVec4 'vector<>':'vector<float, 4>' cinit
- return 1;
-}
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+// CHECK: NamespaceDecl {{.*}} implicit hlsl
+// CHECK: TypeAliasTemplateDecl {{.*}} implicit vector
+// CHECK-NEXT: TemplateTypeParmDecl {{.*}} class depth 0 index 0 element
+// CHECK-NEXT: TemplateArgument type 'float'
+// CHECK-NEXT: BuiltinType {{.*}} 'float'
+// CHECK-NEXT: NonTypeTemplateParmDecl {{.*}} 'int' depth 0 index 1 element_count
+// CHECK-NEXT: TemplateArgument expr
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 4
+// CHECK-NEXT: TypeAliasDecl {{.*}} implicit vector 'vector<element, element_count>'
+// CHECK-NEXT: DependentSizedExtVectorType {{.*}} 'vector<element, element_count>' dependent
+// CHECK-NEXT: TemplateTypeParmType {{.*}} 'element' dependent depth 0 index 0
+// CHECK-NEXT: TemplateTypeParm {{.*}} 'element'
+// CHECK-NEXT: DeclRefExpr {{.*}} 'int' lvalue
+// NonTypeTemplateParm {{.*}} 'element_count' 'int'
+
+// Make sure we got a using directive at the end.
+// CHECK: UsingDirectiveDecl {{.*}} Namespace {{.*}} 'hlsl'
+
+[numthreads(1,1,1)]
+int entry() {
+ // Verify that the alias is generated inside the hlsl namespace.
+ hlsl::vector<float, 2> Vec2 = {1.0, 2.0};
+
+ // CHECK: DeclStmt
+ // CHECK-NEXT: VarDecl {{.*}} Vec2 'hlsl::vector<float, 2>':'vector<float, 2>' cinit
+
+ // Verify that you don't need to specify the namespace.
+ vector<int, 2> Vec2a = {1, 2};
+
+ // CHECK: DeclStmt
+ // CHECK-NEXT: VarDecl {{.*}} Vec2a 'vector<int, 2>' cinit
+
+ // Build a bigger vector.
+ vector<double, 4> Vec4 = {1.0, 2.0, 3.0, 4.0};
+
+ // CHECK: DeclStmt
+ // CHECK-NEXT: VarDecl {{.*}} used Vec4 'vector<double, 4>' cinit
+
+ // Verify that swizzles still work.
+ vector<double, 3> Vec3 = Vec4.xyz;
+
+ // CHECK: DeclStmt {{.*}}
+ // CHECK-NEXT: VarDecl {{.*}} Vec3 'vector<double, 3>' cinit
+
+ // Verify that the implicit arguments generate the correct type.
+ vector<> ImpVec4 = {1.0, 2.0, 3.0, 4.0};
+
+ // CHECK: DeclStmt
+ // CHECK-NEXT: VarDecl {{.*}} ImpVec4 'vector<>':'vector<float, 4>' cinit
+ return 1;
+}
diff --git a/clang/test/CodeGenHLSL/inline/SpirvType.alignment.hlsl b/clang/test/CodeGenHLSL/inline/SpirvType.alignment.hlsl
new file mode 100644
index 0000000000000..4cd8f2bf914aa
--- /dev/null
+++ b/clang/test/CodeGenHLSL/inline/SpirvType.alignment.hlsl
@@ -0,0 +1,16 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s
+
+using Int = vk::SpirvType</* OpTypeInt */ 21, 4, 64, vk::Literal<vk::integral_constant<uint, 8>>, vk::Literal<vk::integral_constant<bool, false>>>;
+
+// CHECK: %struct.S = type <{ i32, [4 x i8], target("spirv.Type", target("spirv.Literal", 8), target("spirv.Literal", 0), 21, 4, 64), [8 x i8] }>
+struct S {
+ int a;
+ Int b;
+};
+
+[numthreads(1,1,1)]
+void main() {
+ S value;
+}
diff --git a/clang/test/CodeGenHLSL/inline/SpirvType.dx.error.hlsl b/clang/test/CodeGenHLSL/inline/SpirvType.dx.error.hlsl
new file mode 100644
index 0000000000000..8c7140689ce74
--- /dev/null
+++ b/clang/test/CodeGenHLSL/inline/SpirvType.dx.error.hlsl
@@ -0,0 +1,12 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: dxil-pc-shadermodel6.0-compute %s \
+// RUN: -fsyntax-only -verify
+
+typedef vk::SpirvType<12, 2, 4, float> InvalidType1; // expected-error {{use of undeclared identifier 'vk'}}
+vk::Literal<nullptr> Unused; // expected-error {{use of undeclared identifier 'vk'}}
+vk::integral_constant<uint, 456> Unused2; // expected-error {{use of undeclared identifier 'vk'}}
+typedef vk::SpirvOpaqueType<12, float> InvalidType2; // expected-error {{use of undeclared identifier 'vk'}}
+
+[numthreads(1, 1, 1)]
+void main() {
+}
diff --git a/clang/test/CodeGenHLSL/inline/SpirvType.hlsl b/clang/test/CodeGenHLSL/inline/SpirvType.hlsl
new file mode 100644
index 0000000000000..ea013c62899f8
--- /dev/null
+++ b/clang/test/CodeGenHLSL/inline/SpirvType.hlsl
@@ -0,0 +1,68 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \
+// RUN: -o - | FileCheck %s
+
+template<class T, uint64_t Size>
+using Array = vk::SpirvOpaqueType</* OpTypeArray */ 28, T, vk::integral_constant<uint64_t, Size>>;
+
+template<uint64_t Size>
+using ArrayBuffer = Array<RWBuffer<float>, Size>;
+
+typedef vk::SpirvType</* OpTypeInt */ 21, 4, 32, vk::Literal<vk::integral_constant<uint, 32>>, vk::Literal<vk::integral_constant<bool, false>>> Int;
+
+typedef Array<Int, 5> ArrayInt;
+
+// CHECK: %struct.S = type { target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0), target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) }
+struct S {
+ ArrayBuffer<4> b;
+ Int i;
+};
+
+// CHECK: define spir_func target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0) @_Z14getArrayBufferu17spirv_type_28_0_0U5_TypeN4hlsl8RWBufferIfEEU6_ConstLm4E(target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0) %v) #0
+ArrayBuffer<4> getArrayBuffer(ArrayBuffer<4> v) {
+ return v;
+}
+
+// CHECK: define spir_func target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) @_Z6getIntu18spirv_type_21_4_32U4_LitLi32EU4_LitLi0E(target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) %v) #0
+Int getInt(Int v) {
+ return v;
+}
+
+// TODO: uncomment and test once CBuffer handles are implemented for SPIR-V
+// ArrayBuffer<4> g_buffers;
+// Int g_word;
+
+[numthreads(1, 1, 1)]
+void main() {
+ // CHECK: %buffers = alloca target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0), align 4
+ ArrayBuffer<4> buffers;
+
+ // CHECK: %longBuffers = alloca target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 591751049, 1), 28, 0, 0), align 4
+ ArrayBuffer<0x123456789> longBuffers;
+
+ // CHECK: %word = alloca target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32), align 4
+ Int word;
+
+ // CHECK: %words = alloca [4 x target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32)], align 4
+ Int words[4];
+
+ // CHECK: %words2 = alloca target("spirv.Type", target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32), target("spirv.IntegralConstant", i64, 5), 28, 0, 0), align 4
+ ArrayInt words2;
+
+ // CHECK: %value = alloca %struct.S, align 4
+ S value;
+
+ // CHECK: %buffers2 = alloca target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0), align 4
+ // CHECK: %word2 = alloca target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32), align 4
+
+
+ // CHECK: [[loaded:%[0-9]+]] = load target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0), ptr %buffers, align 4
+ // CHECK: %call1 = call spir_func target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0) @_Z14getArrayBufferu17spirv_type_28_0_0U5_TypeN4hlsl8RWBufferIfEEU6_ConstLm4E(target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0) [[loaded]])
+ // CHECK: store target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0) %call1, ptr %buffers2, align 4
+ ArrayBuffer<4> buffers2 = getArrayBuffer(buffers);
+
+ // CHECK: [[loaded:%[0-9]+]] = load target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32), ptr %word, align 4
+ // CHECK: %call2 = call spir_func target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) @_Z6getIntu18spirv_type_21_4_32U4_LitLi32EU4_LitLi0E(target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) [[loaded]])
+ // CHECK: store target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) %call2, ptr %word2, align 4
+ Int word2 = getInt(word);
+}
diff --git a/clang/test/CodeGenHLSL/inline/SpirvType.incomplete.hlsl b/clang/test/CodeGenHLSL/inline/SpirvType.incomplete.hlsl
new file mode 100644
index 0000000000000..9f4596e6974ee
--- /dev/null
+++ b/clang/test/CodeGenHLSL/inline/SpirvType.incomplete.hlsl
@@ -0,0 +1,14 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-unknown-vulkan-compute %s -fsyntax-only -verify
+
+struct S; // expected-note {{forward declaration of 'S'}}
+
+// expected-error@hlsl/hlsl_spirv.h:26 {{argument type 'S' is incomplete}}
+
+typedef vk::SpirvOpaqueType</* OpTypeArray */ 28, S, vk::integral_constant<uint, 4>> ArrayOfS; // #1
+// expected-note@#1 {{in instantiation of template type alias 'SpirvOpaqueType' requested here}}
+
+[numthreads(1, 1, 1)]
+void main() {
+ ArrayOfS buffers;
+}
diff --git a/clang/test/CodeGenHLSL/inline/SpirvType.literal.error.hlsl b/clang/test/CodeGenHLSL/inline/SpirvType.literal.error.hlsl
new file mode 100644
index 0000000000000..44d7e855ba5cd
--- /dev/null
+++ b/clang/test/CodeGenHLSL/inline/SpirvType.literal.error.hlsl
@@ -0,0 +1,11 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
+// RUN: spirv-unknown-vulkan-compute %s -fsyntax-only -verify
+
+// expected-error@hlsl/hlsl_spirv.h:20 {{the argument to vk::Literal must be a vk::integral_constant}}
+
+typedef vk::SpirvOpaqueType<28, vk::Literal<float>> T; // #1
+// expected-note@#1 {{in instantiation of template type alias 'SpirvOpaqueType' requested here}}
+
+[numthreads(1, 1, 1)]
+void main() {
+}
diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
index 197ba2cd6856e..abbc6a7ccb6eb 100644
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -1794,6 +1794,11 @@ bool CursorVisitor::VisitHLSLAttributedResourceTypeLoc(
return Visit(TL.getWrappedLoc());
}
+bool CursorVisitor::VisitHLSLInlineSpirvTypeLoc(HLSLInlineSpirvTypeLoc TL) {
+ // Nothing to do.
+ return false;
+}
+
bool CursorVisitor::VisitFunctionTypeLoc(FunctionTypeLoc TL,
bool SkipResultType) {
if (!SkipResultType && Visit(TL.getReturnLoc()))
diff --git a/clang/tools/libclang/CXType.cpp b/clang/tools/libclang/CXType.cpp
index 2c9ef282b8abc..225790a5ffd80 100644
--- a/clang/tools/libclang/CXType.cpp
+++ b/clang/tools/libclang/CXType.cpp
@@ -636,6 +636,7 @@ CXString clang_getTypeKindSpelling(enum CXTypeKind K) {
TKIND(Attributed);
TKIND(BTFTagAttributed);
TKIND(HLSLAttributedResource);
+ TKIND(HLSLInlineSpirv);
TKIND(BFloat16);
#define IMAGE_TYPE(ImgType, Id, SingletonId, Access, Suffix) TKIND(Id);
#include "clang/Basic/OpenCLImageTypes.def"
diff --git a/clang/utils/TableGen/ClangBuiltinTemplatesEmitter.cpp b/clang/utils/TableGen/ClangBuiltinTemplatesEmitter.cpp
index 34bc782e007d5..797e6c3f4d04b 100644
--- a/clang/utils/TableGen/ClangBuiltinTemplatesEmitter.cpp
+++ b/clang/utils/TableGen/ClangBuiltinTemplatesEmitter.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "TableGenBackends.h"
+#include "llvm/ADT/StringSet.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/TableGenBackend.h"
@@ -21,11 +22,14 @@ using namespace llvm;
static std::string TemplateNameList;
static std::string CreateBuiltinTemplateParameterList;
+static llvm::StringSet BuiltinClasses;
+
namespace {
struct ParserState {
size_t UniqueCounter = 0;
size_t CurrentDepth = 0;
bool EmittedSizeTInfo = false;
+ bool EmittedUint32TInfo = false;
};
std::pair<std::string, std::string>
@@ -62,7 +66,7 @@ ParseTemplateParameterList(ParserState &PS,
if (TemplateNameToParmName.find(Type.str()) ==
TemplateNameToParmName.end()) {
- PrintFatalError("Unkown Type Name");
+ PrintFatalError("Unknown Type Name");
}
auto TSIName = "TSI" + std::to_string(PS.UniqueCounter++);
@@ -75,19 +79,32 @@ ParseTemplateParameterList(ParserState &PS,
<< TSIName << "->getType(), " << Arg->getValueAsBit("IsVariadic")
<< ", " << TSIName << ");\n";
} else if (Arg->isSubClassOf("BuiltinNTTP")) {
- if (Arg->getValueAsString("TypeName") != "size_t")
- PrintFatalError("Unkown Type Name");
- if (!PS.EmittedSizeTInfo) {
- Code << "TypeSourceInfo *SizeTInfo = "
- "C.getTrivialTypeSourceInfo(C.getSizeType());\n";
- PS.EmittedSizeTInfo = true;
+ std::string SourceInfo;
+ if (Arg->getValueAsString("TypeName") == "size_t") {
+ SourceInfo = "SizeTInfo";
+ if (!PS.EmittedSizeTInfo) {
+ Code << "TypeSourceInfo *SizeTInfo = "
+ "C.getTrivialTypeSourceInfo(C.getSizeType());\n";
+ PS.EmittedSizeTInfo = true;
+ }
+ } else if (Arg->getValueAsString("TypeName") == "uint32_t") {
+ SourceInfo = "Uint32TInfo";
+ if (!PS.EmittedUint32TInfo) {
+ Code << "TypeSourceInfo *Uint32TInfo = "
+ "C.getTrivialTypeSourceInfo(C.UnsignedIntTy);\n";
+ PS.EmittedUint32TInfo = true;
+ }
+ } else {
+ PrintFatalError("Unknown Type Name");
}
Code << " auto *" << ParmName
<< " = NonTypeTemplateParmDecl::Create(C, DC, SourceLocation(), "
"SourceLocation(), "
- << PS.CurrentDepth << ", " << Position++
- << ", /*Id=*/nullptr, SizeTInfo->getType(), "
- "/*ParameterPack=*/false, SizeTInfo);\n";
+ << PS.CurrentDepth << ", " << Position++ << ", /*Id=*/nullptr, "
+ << SourceInfo
+ << "->getType(), "
+ "/*ParameterPack=*/false, "
+ << SourceInfo << ");\n";
} else {
PrintFatalError("Unknown Argument Type");
}
@@ -134,7 +151,8 @@ EmitCreateBuiltinTemplateParameterList(std::vector<const Record *> TemplateArgs,
CreateBuiltinTemplateParameterList += " }\n";
}
-void EmitBuiltinTemplate(raw_ostream &OS, const Record *BuiltinTemplate) {
+void EmitBuiltinTemplate(const Record *BuiltinTemplate) {
+ auto Class = BuiltinTemplate->getType()->getAsString();
auto Name = BuiltinTemplate->getName();
std::vector<const Record *> TemplateHead =
@@ -142,21 +160,49 @@ void EmitBuiltinTemplate(raw_ostream &OS, const Record *BuiltinTemplate) {
EmitCreateBuiltinTemplateParameterList(TemplateHead, Name);
- TemplateNameList += "BuiltinTemplate(";
+ TemplateNameList += Class + "(";
TemplateNameList += Name;
TemplateNameList += ")\n";
+
+ BuiltinClasses.insert(Class);
+}
+
+void EmitDefaultDefine(llvm::raw_ostream &OS, StringRef Name) {
+ OS << "#ifndef " << Name << "\n";
+ OS << "#define " << Name << "(NAME)" << " " << "BuiltinTemplate"
+ << "(NAME)\n";
+ OS << "#endif\n\n";
+}
+
+void EmitUndef(llvm::raw_ostream &OS, StringRef Name) {
+ OS << "#undef " << Name << "\n";
}
} // namespace
void clang::EmitClangBuiltinTemplates(const llvm::RecordKeeper &Records,
llvm::raw_ostream &OS) {
emitSourceFileHeader("Tables and code for Clang's builtin templates", OS);
+
for (const auto *Builtin :
Records.getAllDerivedDefinitions("BuiltinTemplate"))
- EmitBuiltinTemplate(OS, Builtin);
+ EmitBuiltinTemplate(Builtin);
+
+ for (const auto &ClassEntry : BuiltinClasses) {
+ StringRef Class = ClassEntry.getKey();
+ if (Class == "BuiltinTemplate")
+ continue;
+ EmitDefaultDefine(OS, Class);
+ }
OS << "#if defined(CREATE_BUILTIN_TEMPLATE_PARAMETER_LIST)\n"
<< CreateBuiltinTemplateParameterList
<< "#undef CREATE_BUILTIN_TEMPLATE_PARAMETER_LIST\n#else\n"
<< TemplateNameList << "#undef BuiltinTemplate\n#endif\n";
+
+ for (const auto &ClassEntry : BuiltinClasses) {
+ StringRef Class = ClassEntry.getKey();
+ if (Class == "BuiltinTemplate")
+ continue;
+ EmitUndef(OS, Class);
+ }
}
>From 7bdc924c57eb4b7e9e28143eff87086310e7b080 Mon Sep 17 00:00:00 2001
From: Cassandra Beckley <cbeckley at google.com>
Date: Tue, 1 Apr 2025 23:34:31 -0700
Subject: [PATCH 2/2] Fix formatting; remove unused code
---
clang/lib/Headers/hlsl/hlsl_spirv.h | 30 ++++++++++++++---------------
1 file changed, 14 insertions(+), 16 deletions(-)
diff --git a/clang/lib/Headers/hlsl/hlsl_spirv.h b/clang/lib/Headers/hlsl/hlsl_spirv.h
index 8a71699a4ed5c..711da2fea46a4 100644
--- a/clang/lib/Headers/hlsl/hlsl_spirv.h
+++ b/clang/lib/Headers/hlsl/hlsl_spirv.h
@@ -10,21 +10,19 @@
#define _HLSL_HLSL_SPIRV_H_
namespace hlsl {
- namespace vk {
- // template <class T> using Foo = __hlsl_spirv_t;
- // typedef Foo
- template <typename T, T v> struct integral_constant {
- static constexpr T value = v;
- };
-
- template <typename T> struct Literal {};
-
- template <uint Opcode, uint Size, uint Alignment, typename... Operands>
- using SpirvType = __hlsl_spirv_type<Opcode, Size, Alignment, Operands...>;
-
- template <uint Opcode, typename... Operands>
- using SpirvOpaqueType = __hlsl_spirv_type<Opcode, 0, 0, Operands...>;
- } // namespace vk
- } // namespace hlsl
+namespace vk {
+template <typename T, T v> struct integral_constant {
+ static constexpr T value = v;
+};
+
+template <typename T> struct Literal {};
+
+template <uint Opcode, uint Size, uint Alignment, typename... Operands>
+using SpirvType = __hlsl_spirv_type<Opcode, Size, Alignment, Operands...>;
+
+template <uint Opcode, typename... Operands>
+using SpirvOpaqueType = __hlsl_spirv_type<Opcode, 0, 0, Operands...>;
+} // namespace vk
+} // namespace hlsl
#endif // _HLSL_HLSL_SPIRV_H_
\ No newline at end of file
More information about the cfe-commits
mailing list