[clang] b8dbc6f - [HLSL] Add ExternalSemaSource & vector alias

Chris Bieneman via cfe-commits cfe-commits at lists.llvm.org
Tue Jul 5 09:30:41 PDT 2022


Author: Chris Bieneman
Date: 2022-07-05T11:30:29-05:00
New Revision: b8dbc6ffea93976dc0d8569c9d23e9c21e33e317

URL: https://github.com/llvm/llvm-project/commit/b8dbc6ffea93976dc0d8569c9d23e9c21e33e317
DIFF: https://github.com/llvm/llvm-project/commit/b8dbc6ffea93976dc0d8569c9d23e9c21e33e317.diff

LOG: [HLSL] Add ExternalSemaSource & vector alias

HLSL vector types are ext_vector types, but they are also exposed via a
template syntax `vector<T, #>`. This is morally equivalent to the code:

```c++
template <typename T, int Size>
using vector = T __attribute__((ext_vector_type(Size)));
```

The problem is that templates aren't supported before HLSL 2021, and
type aliases still aren't supported in HLSL.

To resolve this (and other issues where HLSL can't represent its own
types), we rely on an external AST & Sema source being registered for
HLSL code.

This patch adds the HLSLExternalSemaSource and registers the vector
type alias.

Depends on D127802

Differential Revision: https://reviews.llvm.org/D128012

Added: 
    clang/include/clang/Sema/HLSLExternalSemaSource.h
    clang/lib/Sema/HLSLExternalSemaSource.cpp
    clang/test/AST/HLSL/vector-alias.hlsl

Modified: 
    clang/lib/Frontend/FrontendAction.cpp
    clang/lib/Headers/hlsl/hlsl_basic_types.h
    clang/lib/Sema/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Sema/HLSLExternalSemaSource.h b/clang/include/clang/Sema/HLSLExternalSemaSource.h
new file mode 100644
index 0000000000000..92154427a3e72
--- /dev/null
+++ b/clang/include/clang/Sema/HLSLExternalSemaSource.h
@@ -0,0 +1,43 @@
+//===--- HLSLExternalSemaSource.h - HLSL Sema Source ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//  This file defines the HLSLExternalSemaSource interface.
+//
+//===----------------------------------------------------------------------===//
+#ifndef CLANG_SEMA_HLSLEXTERNALSEMASOURCE_H
+#define CLANG_SEMA_HLSLEXTERNALSEMASOURCE_H
+
+#include "clang/Sema/ExternalSemaSource.h"
+
+namespace clang {
+class NamespaceDecl;
+class Sema;
+
+class HLSLExternalSemaSource : public ExternalSemaSource {
+  static char ID;
+
+  Sema *SemaPtr = nullptr; // Set by InitializeSema, reset by ForgetSema.
+  NamespaceDecl *HLSLNamespace = nullptr; // Implicit `hlsl` namespace.
+
+  void defineHLSLVectorAlias(); // Adds the implicit hlsl::vector alias.
+
+public:
+  ~HLSLExternalSemaSource() override;
+
+  /// Initialize the semantic source with the Sema instance
+  /// being used to perform semantic analysis on the abstract syntax
+  /// tree.
+  void InitializeSema(Sema &S) override;
+
+  /// Inform the semantic consumer that Sema is no longer available.
+  void ForgetSema() override { SemaPtr = nullptr; }
+};
+
+} // namespace clang
+
+#endif // CLANG_SEMA_HLSLEXTERNALSEMASOURCE_H

diff  --git a/clang/lib/Frontend/FrontendAction.cpp b/clang/lib/Frontend/FrontendAction.cpp
index 65160dd7e0b18..81915e6330b03 100644
--- a/clang/lib/Frontend/FrontendAction.cpp
+++ b/clang/lib/Frontend/FrontendAction.cpp
@@ -24,6 +24,7 @@
 #include "clang/Lex/Preprocessor.h"
 #include "clang/Lex/PreprocessorOptions.h"
 #include "clang/Parse/ParseAST.h"
+#include "clang/Sema/HLSLExternalSemaSource.h"
 #include "clang/Serialization/ASTDeserializationListener.h"
 #include "clang/Serialization/ASTReader.h"
 #include "clang/Serialization/GlobalModuleIndex.h"
@@ -1014,6 +1015,13 @@ bool FrontendAction::BeginSourceFile(CompilerInstance &CI,
     CI.getASTContext().setExternalSource(Override);
   }
 
+  // Setup HLSL External Sema Source
+  if (CI.getLangOpts().HLSL && CI.hasASTContext()) {
+    IntrusiveRefCntPtr<ExternalASTSource> HLSLSema(
+        new HLSLExternalSemaSource());
+    CI.getASTContext().setExternalSource(HLSLSema);
+  }
+
   FailureCleanup.release();
   return true;
 }

diff  --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h b/clang/lib/Headers/hlsl/hlsl_basic_types.h
index 2069990f5c06c..e68715f1a6a45 100644
--- a/clang/lib/Headers/hlsl/hlsl_basic_types.h
+++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h
@@ -27,38 +27,38 @@ typedef long int64_t;
 // built-in vector data types:
 
 #ifdef __HLSL_ENABLE_16_BIT
-typedef int16_t int16_t2 __attribute__((ext_vector_type(2)));
-typedef int16_t int16_t3 __attribute__((ext_vector_type(3)));
-typedef int16_t int16_t4 __attribute__((ext_vector_type(4)));
-typedef uint16_t uint16_t2 __attribute__((ext_vector_type(2)));
-typedef uint16_t uint16_t3 __attribute__((ext_vector_type(3)));
-typedef uint16_t uint16_t4 __attribute__((ext_vector_type(4)));
+typedef vector<int16_t, 2> int16_t2;
+typedef vector<int16_t, 3> int16_t3;
+typedef vector<int16_t, 4> int16_t4;
+typedef vector<uint16_t, 2> uint16_t2;
+typedef vector<uint16_t, 3> uint16_t3;
+typedef vector<uint16_t, 4> uint16_t4;
 #endif
 
-typedef int int2 __attribute__((ext_vector_type(2)));
-typedef int int3 __attribute__((ext_vector_type(3)));
-typedef int int4 __attribute__((ext_vector_type(4)));
-typedef uint uint2 __attribute__((ext_vector_type(2)));
-typedef uint uint3 __attribute__((ext_vector_type(3)));
-typedef uint uint4 __attribute__((ext_vector_type(4)));
-typedef int64_t int64_t2 __attribute__((ext_vector_type(2)));
-typedef int64_t int64_t3 __attribute__((ext_vector_type(3)));
-typedef int64_t int64_t4 __attribute__((ext_vector_type(4)));
-typedef uint64_t uint64_t2 __attribute__((ext_vector_type(2)));
-typedef uint64_t uint64_t3 __attribute__((ext_vector_type(3)));
-typedef uint64_t uint64_t4 __attribute__((ext_vector_type(4)));
+typedef vector<int, 2> int2;
+typedef vector<int, 3> int3;
+typedef vector<int, 4> int4;
+typedef vector<uint, 2> uint2;
+typedef vector<uint, 3> uint3;
+typedef vector<uint, 4> uint4;
+typedef vector<int64_t, 2> int64_t2;
+typedef vector<int64_t, 3> int64_t3;
+typedef vector<int64_t, 4> int64_t4;
+typedef vector<uint64_t, 2> uint64_t2;
+typedef vector<uint64_t, 3> uint64_t3;
+typedef vector<uint64_t, 4> uint64_t4;
 
 #ifdef __HLSL_ENABLE_16_BIT
-typedef half half2 __attribute__((ext_vector_type(2)));
-typedef half half3 __attribute__((ext_vector_type(3)));
-typedef half half4 __attribute__((ext_vector_type(4)));
+typedef vector<half, 2> half2;
+typedef vector<half, 3> half3;
+typedef vector<half, 4> half4;
 #endif
 
-typedef float float2 __attribute__((ext_vector_type(2)));
-typedef float float3 __attribute__((ext_vector_type(3)));
-typedef float float4 __attribute__((ext_vector_type(4)));
-typedef double double2 __attribute__((ext_vector_type(2)));
-typedef double double3 __attribute__((ext_vector_type(3)));
-typedef double double4 __attribute__((ext_vector_type(4)));
+typedef vector<float, 2> float2;
+typedef vector<float, 3> float3;
+typedef vector<float, 4> float4;
+typedef vector<double, 2> double2;
+typedef vector<double, 3> double3;
+typedef vector<double, 4> double4;
 
 #endif //_HLSL_HLSL_BASIC_TYPES_H_

diff  --git a/clang/lib/Sema/CMakeLists.txt b/clang/lib/Sema/CMakeLists.txt
index 0e0681a8e2927..2901dc4401a30 100644
--- a/clang/lib/Sema/CMakeLists.txt
+++ b/clang/lib/Sema/CMakeLists.txt
@@ -15,6 +15,7 @@ add_clang_library(clangSema
   CodeCompleteConsumer.cpp
   DeclSpec.cpp
   DelayedDiagnostic.cpp
+  HLSLExternalSemaSource.cpp
   IdentifierResolver.cpp
   JumpDiagnostics.cpp
   MultiplexExternalSemaSource.cpp

diff  --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
new file mode 100644
index 0000000000000..2de90378baf90
--- /dev/null
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -0,0 +1,98 @@
+//===--- HLSLExternalSemaSource.cpp - HLSL Sema Source --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//  This file implements the HLSLExternalSemaSource interface.
+//===----------------------------------------------------------------------===//
+
+#include "clang/Sema/HLSLExternalSemaSource.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/DeclCXX.h"
+#include "clang/Basic/AttrKinds.h"
+#include "clang/Sema/Sema.h"
+
+using namespace clang;
+
+char HLSLExternalSemaSource::ID;
+
+HLSLExternalSemaSource::~HLSLExternalSemaSource() {}
+
+void HLSLExternalSemaSource::InitializeSema(Sema &S) {
+  SemaPtr = &S; // Kept until ForgetSema() resets it.
+  ASTContext &AST = SemaPtr->getASTContext();
+  IdentifierInfo &HLSL = AST.Idents.get("hlsl", tok::TokenKind::identifier);
+  HLSLNamespace =
+      NamespaceDecl::Create(AST, AST.getTranslationUnitDecl(), false,
+                            SourceLocation(), SourceLocation(), &HLSL, nullptr);
+  HLSLNamespace->setImplicit(true);
+  AST.getTranslationUnitDecl()->addDecl(HLSLNamespace);
+  defineHLSLVectorAlias(); // Uses SemaPtr and HLSLNamespace set above.
+
+  // This adds a `using namespace hlsl` directive. In DXC, we don't put HLSL's
+  // built-in types inside a namespace, but we are planning to change that in
+  // the near future. To stay source compatible, older versions of HLSL will
+  // need to implicitly use the hlsl namespace. For now in clang everything
+  // will get added to the namespace, and we can remove the using directive for
+  // future language versions to match HLSL's evolution.
+  auto *UsingDecl = UsingDirectiveDecl::Create(
+      AST, AST.getTranslationUnitDecl(), SourceLocation(), SourceLocation(),
+      NestedNameSpecifierLoc(), SourceLocation(), HLSLNamespace,
+      AST.getTranslationUnitDecl());
+
+  AST.getTranslationUnitDecl()->addDecl(UsingDecl);
+}
+
+void HLSLExternalSemaSource::defineHLSLVectorAlias() {
+  ASTContext &AST = SemaPtr->getASTContext();
+  // Declares hlsl::vector<element = float, element_count = 4> implicitly.
+  llvm::SmallVector<NamedDecl *> TemplateParams;
+  // Type parameter `element` (depth 0, index 0), defaulting to float.
+  auto *TypeParam = TemplateTypeParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
+      &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
+  TypeParam->setDefaultArgument(AST.getTrivialTypeSourceInfo(AST.FloatTy));
+
+  TemplateParams.emplace_back(TypeParam);
+  // Non-type parameter `int element_count` (index 1), defaulting to 4.
+  auto *SizeParam = NonTypeTemplateParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
+      &AST.Idents.get("element_count", tok::TokenKind::identifier), AST.IntTy,
+      false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+  Expr *LiteralExpr =
+      IntegerLiteral::Create(AST, llvm::APInt(AST.getIntWidth(AST.IntTy), 4),
+                             AST.IntTy, SourceLocation());
+  SizeParam->setDefaultArgument(LiteralExpr);
+  TemplateParams.emplace_back(SizeParam);
+
+  auto *ParamList =
+      TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
+                                    TemplateParams, SourceLocation(), nullptr);
+
+  IdentifierInfo &II = AST.Idents.get("vector", tok::TokenKind::identifier);
+  // Aliased type: element __attribute__((ext_vector_type(element_count))).
+  QualType AliasType = AST.getDependentSizedExtVectorType(
+      AST.getTemplateTypeParmType(0, 0, false, TypeParam),
+      DeclRefExpr::Create(
+          AST, NestedNameSpecifierLoc(), SourceLocation(), SizeParam, false,
+          DeclarationNameInfo(SizeParam->getDeclName(), SourceLocation()),
+          AST.IntTy, VK_LValue),
+      SourceLocation());
+  // Despite its name, `Record` is the TypeAliasDecl the template describes.
+  auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
+                                       SourceLocation(), &II,
+                                       AST.getTrivialTypeSourceInfo(AliasType));
+  Record->setImplicit(true);
+
+  auto *Template =
+      TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
+                                    Record->getIdentifier(), ParamList, Record);
+  // Link alias and template, then publish the template in the namespace.
+  Record->setDescribedAliasTemplate(Template);
+  Template->setImplicit(true);
+  Template->setLexicalDeclContext(Record->getDeclContext());
+  HLSLNamespace->addDecl(Template);
+}

diff  --git a/clang/test/AST/HLSL/vector-alias.hlsl b/clang/test/AST/HLSL/vector-alias.hlsl
new file mode 100644
index 0000000000000..effa1aa53db49
--- /dev/null
+++ b/clang/test/AST/HLSL/vector-alias.hlsl
@@ -0,0 +1,53 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s 
+
+// CHECK: NamespaceDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit hlsl
+// CHECK-NEXT: TypeAliasTemplateDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector
+// CHECK-NEXT: TemplateTypeParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> class depth 0 index 0 element
+// CHECK-NEXT: TemplateArgument type 'float'
+// CHECK-NEXT: BuiltinType 0x{{[0-9a-fA-F]+}} 'float'
+// CHECK-NEXT: NonTypeTemplateParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> 'int' depth 0 index 1 element_count
+// CHECK-NEXT: TemplateArgument expr
+// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' 4
+// CHECK-NEXT: TypeAliasDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector 'element __attribute__((ext_vector_type(element_count)))'
+// CHECK-NEXT: DependentSizedExtVectorType 0x{{[0-9a-fA-F]+}} 'element __attribute__((ext_vector_type(element_count)))' dependent <invalid sloc>
+// CHECK-NEXT: TemplateTypeParmType 0x{{[0-9a-fA-F]+}} 'element' dependent depth 0 index 0
+// CHECK-NEXT: TemplateTypeParm 0x{{[0-9a-fA-F]+}} 'element'
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' lvalue
+// CHECK-SAME: NonTypeTemplateParm 0x{{[0-9a-fA-F]+}} 'element_count' 'int'
+
+// Make sure we got a using directive at the end.
+// CHECK: UsingDirectiveDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> Namespace 0x{{[0-9a-fA-F]+}} 'hlsl'
+
+[numthreads(1,1,1)]
+int entry() {
+  // Verify that the alias is generated inside the hlsl namespace.
+  hlsl::vector<float, 2> Vec2 = {1.0, 2.0};
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:24:3, col:43>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:42> col:26 Vec2 'hlsl::vector<float, 2>':'float __attribute__((ext_vector_type(2)))' cinit
+
+  // Verify that you don't need to specify the namespace.
+  vector<int, 2> Vec2a = {1, 2};
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:30:3, col:32>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:31> col:18 Vec2a 'vector<int, 2>':'int __attribute__((ext_vector_type(2)))' cinit
+
+  // Build a bigger vector.
+  vector<double, 4> Vec4 = {1.0, 2.0, 3.0, 4.0};
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:36:3, col:48>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:47> col:21 used Vec4 'vector<double, 4>':'double __attribute__((ext_vector_type(4)))' cinit
+
+  // Verify that swizzles still work.
+  vector<double, 3> Vec3 = Vec4.xyz;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:42:3, col:36>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:33> col:21 Vec3 'vector<double, 3>':'double __attribute__((ext_vector_type(3)))' cinit
+
+  // Verify that the default template arguments (float, 4) yield the right type.
+  vector<> ImpVec4 = {1.0, 2.0, 3.0, 4.0};
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:48:3, col:42>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:41> col:12 ImpVec4 'vector<>':'float __attribute__((ext_vector_type(4)))' cinit
+  return 1;
+}


        


More information about the cfe-commits mailing list