[clang] [HLSL] Parameter modifier parsing and AST (PR #72139)

Chris B via cfe-commits cfe-commits at lists.llvm.org
Wed Nov 15 07:23:52 PST 2023


https://github.com/llvm-beanz updated https://github.com/llvm/llvm-project/pull/72139

>From 8be56542f25b08ee5d6a325f76d12c28c4a366d7 Mon Sep 17 00:00:00 2001
From: Chris Bieneman <chris.bieneman at me.com>
Date: Mon, 13 Nov 2023 11:14:06 -0600
Subject: [PATCH 1/2] [HLSL] Parameter modifier parsing and AST

This change implements parsing for HLSL's parameter modifier keywords
`in`, `out` and `inout`. Because HLSL doesn't support references or
pointers, these keywords are used to allow parameters to be passed in
and out of functions.

This change only implements the parsing and AST support. In the HLSL
ASTs we represent `out` and `inout` parameters as references, and we
implement the semantics of by-value passing during IR generation.

In HLSL, parameters marked `out` and `inout` are indistinguishable in
function declarations, and `in`, `out` and `inout` may be ambiguous at
call sites.

This means a function may be overloaded as `fn(in T)` together with
either `fn(inout T)` or `fn(out T)`, but not as both `fn(inout T)` and
`fn(out T)`. If a function `fn` is declared with both an `in` and an
`inout` (or `out`) parameter, a call will be ambiguous in the same way a
C++ call would be ambiguous given the declarations `fn(T)` and `fn(T&)`.
---
 clang/include/clang/Basic/Attr.td             | 12 ++++
 clang/include/clang/Basic/AttrDocs.td         | 19 ++++++
 .../clang/Basic/DiagnosticSemaKinds.td        |  4 ++
 clang/include/clang/Basic/TokenKinds.def      |  3 +
 clang/include/clang/Sema/Sema.h               |  3 +
 clang/lib/AST/TypePrinter.cpp                 |  4 ++
 clang/lib/Parse/ParseDecl.cpp                 | 14 +++-
 clang/lib/Parse/ParseTentative.cpp            |  3 +
 clang/lib/Sema/SemaDecl.cpp                   | 20 ++++++
 clang/lib/Sema/SemaDeclAttr.cpp               | 33 ++++++++++
 clang/lib/Sema/SemaType.cpp                   | 13 ++++
 clang/test/SemaHLSL/parameter_modifiers.hlsl  | 64 +++++++++++++++++++
 .../SemaHLSL/parameter_modifiers_ast.hlsl     | 36 +++++++++++
 13 files changed, 227 insertions(+), 1 deletion(-)
 create mode 100644 clang/test/SemaHLSL/parameter_modifiers.hlsl
 create mode 100644 clang/test/SemaHLSL/parameter_modifiers_ast.hlsl

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 31434565becaec6..01eb0c3f9290780 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4229,6 +4229,18 @@ def HLSLGroupSharedAddressSpace : TypeAttr {
   let Documentation = [HLSLGroupSharedAddressSpaceDocs];
 }
 
+def HLSLParamModifier : TypeAttr {
+  let Spellings = [CustomKeyword<"in">, CustomKeyword<"inout">, CustomKeyword<"out">];
+  let Accessors = [Accessor<"isIn", [CustomKeyword<"in">]>,
+                   Accessor<"isInOut", [CustomKeyword<"inout">]>,
+                   Accessor<"isOut", [CustomKeyword<"out">]>,
+                   Accessor<"isAnyOut", [CustomKeyword<"out">, CustomKeyword<"inout">]>,
+                   Accessor<"isAnyIn", [CustomKeyword<"in">, CustomKeyword<"inout">]>];
+  let Subjects = SubjectList<[ParmVar]>;
+  let Documentation = [HLSLParamQualifierDocs];
+  let Args = [DefaultBoolArgument<"MergedSpelling", /*default*/0, /*fake*/1>];
+}
+
 def RandomizeLayout : InheritableAttr {
   let Spellings = [GCC<"randomize_layout">];
   let Subjects = SubjectList<[Record]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index fa6f6acd0c30e88..09090c94ae7fb73 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -7061,6 +7061,25 @@ The full documentation is available here: https://learn.microsoft.com/en-us/wind
   }];
 }
 
+def HLSLParamQualifierDocs : Documentation {
+  let Category = DocCatVariable;
+  let Content = [{
+HLSL function parameters are passed by value. Parameter declarations support
+three qualifiers to denote parameter passing behavior. The three qualifiers are
+`in`, `out` and `inout`.
+
+Parameters annotated with `in` or with no annotation are passed by value from
+the caller to the callee.
+
+Parameters annotated with `out` are written to the argument after the callee
+returns (Note: argument values passed into `out` parameters _are not_ copied
+into the callee).
+
+Parameters annotated with `inout` are copied into the callee via a temporary,
+and copied back to the argument after the callee returns.
+  }];
+}
+
 def AnnotateTypeDocs : Documentation {
   let Category = DocCatType;
   let Heading = "annotate_type";
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 4614324babb1c91..313acc3499e3ffc 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11989,6 +11989,7 @@ def err_hlsl_numthreads_argument_oor : Error<"argument '%select{X|Y|Z}0' to numt
 def err_hlsl_numthreads_invalid : Error<"total number of threads cannot exceed %0">;
 def err_hlsl_missing_numthreads : Error<"missing numthreads attribute for %0 shader entry">;
 def err_hlsl_attribute_param_mismatch : Error<"%0 attribute parameters do not match the previous declaration">;
+def err_hlsl_duplicate_parameter_modifier : Error<"duplicate parameter modifier %0">;
 def err_hlsl_missing_semantic_annotation : Error<
   "semantic annotations must be present for all parameters of an entry "
   "function or patch constant function">;
@@ -12004,6 +12005,9 @@ def err_hlsl_pointers_unsupported : Error<
 def err_hlsl_operator_unsupported : Error<
   "the '%select{&|*|->}0' operator is unsupported in HLSL">;
 
+def err_hlsl_param_qualifier_mismatch :
+  Error<"conflicting parameter qualifier %0 on parameter %1">;
+
 // Layout randomization diagnostics.
 def err_non_designated_init_used : Error<
   "a randomized struct can only be initialized with a designated initializer">;
diff --git a/clang/include/clang/Basic/TokenKinds.def b/clang/include/clang/Basic/TokenKinds.def
index 6cb4b3f250c4032..dda6e2611ba9ac0 100644
--- a/clang/include/clang/Basic/TokenKinds.def
+++ b/clang/include/clang/Basic/TokenKinds.def
@@ -626,6 +626,9 @@ KEYWORD(__noinline__                , KEYCUDA)
 KEYWORD(cbuffer                     , KEYHLSL)
 KEYWORD(tbuffer                     , KEYHLSL)
 KEYWORD(groupshared                 , KEYHLSL)
+KEYWORD(in                          , KEYHLSL)
+KEYWORD(inout                       , KEYHLSL)
+KEYWORD(out                         , KEYHLSL)
 
 // OpenMP Type Traits
 UNARY_EXPR_OR_TYPE_TRAIT(__builtin_omp_required_simd_align, OpenMPRequiredSimdAlign, KEYALL)
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 38377f01a10086f..316e9c129b5324f 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -3732,6 +3732,9 @@ class Sema final {
                                               int X, int Y, int Z);
   HLSLShaderAttr *mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                                       HLSLShaderAttr::ShaderType ShaderType);
+  HLSLParamModifierAttr *
+  mergeHLSLParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
+                             HLSLParamModifierAttr::Spelling Spelling);
 
   void mergeDeclAttributes(NamedDecl *New, Decl *Old,
                            AvailabilityMergeKind AMK = AMK_Redeclaration);
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index e4f5f40cd625996..f69412429273674 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -1894,6 +1894,10 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
   case attr::ArmMveStrictPolymorphism:
     OS << "__clang_arm_mve_strict_polymorphism";
     break;
+
+  // Nothing to print for this attribute.
+  case attr::HLSLParamModifier:
+    break;
   }
   OS << "))";
 }
diff --git a/clang/lib/Parse/ParseDecl.cpp b/clang/lib/Parse/ParseDecl.cpp
index 78c3ab72979a007..f67a73404b9261a 100644
--- a/clang/lib/Parse/ParseDecl.cpp
+++ b/clang/lib/Parse/ParseDecl.cpp
@@ -4528,6 +4528,9 @@ void Parser::ParseDeclarationSpecifiers(
       break;
 
     case tok::kw_groupshared:
+    case tok::kw_in:
+    case tok::kw_inout:
+    case tok::kw_out:
       // NOTE: ParseHLSLQualifiers will consume the qualifier token.
       ParseHLSLQualifiers(DS.getAttributes());
       continue;
@@ -5559,7 +5562,6 @@ bool Parser::isTypeSpecifierQualifier() {
   case tok::kw___read_write:
   case tok::kw___write_only:
   case tok::kw___funcref:
-  case tok::kw_groupshared:
     return true;
 
   case tok::kw_private:
@@ -5568,6 +5570,13 @@ bool Parser::isTypeSpecifierQualifier() {
   // C11 _Atomic
   case tok::kw__Atomic:
     return true;
+
+  // HLSL type qualifiers
+  case tok::kw_groupshared:
+  case tok::kw_in:
+  case tok::kw_inout:
+  case tok::kw_out:
+    return getLangOpts().HLSL;
   }
 }
 
@@ -6067,6 +6076,9 @@ void Parser::ParseTypeQualifierListOpt(
       break;
 
     case tok::kw_groupshared:
+    case tok::kw_in:
+    case tok::kw_inout:
+    case tok::kw_out:
       // NOTE: ParseHLSLQualifiers will consume the qualifier token.
       ParseHLSLQualifiers(DS.getAttributes());
       continue;
diff --git a/clang/lib/Parse/ParseTentative.cpp b/clang/lib/Parse/ParseTentative.cpp
index 28decc4fc43f9b8..d403a71a6e973f0 100644
--- a/clang/lib/Parse/ParseTentative.cpp
+++ b/clang/lib/Parse/ParseTentative.cpp
@@ -1529,6 +1529,9 @@ Parser::isCXXDeclarationSpecifier(ImplicitTypenameContext AllowImplicitTypename,
 
     // HLSL address space qualifiers
   case tok::kw_groupshared:
+  case tok::kw_in:
+  case tok::kw_inout:
+  case tok::kw_out:
 
     // GNU
   case tok::kw_restrict:
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 3876eb501083acb..37a9472595b2171 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -3368,6 +3368,26 @@ static void mergeParamDeclAttributes(ParmVarDecl *newDecl,
            diag::note_carries_dependency_missing_first_decl) << 1/*Param*/;
   }
 
+  // HLSL parameter declarations for inout and out must match between
+  // declarations. In HLSL inout and out are ambiguous at the call site, but
+  // have different calling behavior, so you cannot overload a method based on a
+  // difference between inout and out annotations.
+  if (S.getLangOpts().HLSL) {
+    const auto *NDAttr = newDecl->getAttr<HLSLParamModifierAttr>();
+    const auto *ODAttr = oldDecl->getAttr<HLSLParamModifierAttr>();
+    // We don't need to cover the case where one declaration doesn't have an
+    // attribute. The only possible case there is if one declaration has an `in`
+    // attribute and the other declaration has no attribute. This case is
+    // allowed since parameters are `in` by default.
+    if (NDAttr && ODAttr &&
+        NDAttr->getSpellingListIndex() != ODAttr->getSpellingListIndex()) {
+      S.Diag(newDecl->getLocation(), diag::err_hlsl_param_qualifier_mismatch)
+          << NDAttr << newDecl;
+      S.Diag(oldDecl->getLocation(), diag::note_previous_declaration_as)
+          << ODAttr;
+    }
+  }
+
   if (!oldDecl->hasAttrs())
     return;
 
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index cdb769a883550d0..ea940437f762684 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7306,6 +7306,36 @@ static void handleHLSLResourceBindingAttr(Sema &S, Decl *D,
     D->addAttr(NewAttr);
 }
 
+static void handleHLSLParamModifierAttr(Sema &S, Decl *D,
+                                        const ParsedAttr &AL) {
+  HLSLParamModifierAttr *NewAttr = S.mergeHLSLParamModifierAttr(
+      D, AL,
+      static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling()));
+  if (NewAttr)
+    D->addAttr(NewAttr);
+}
+
+HLSLParamModifierAttr *
+Sema::mergeHLSLParamModifierAttr(Decl *D, const AttributeCommonInfo &AL,
+                                 HLSLParamModifierAttr::Spelling Spelling) {
+  // We can only merge an `in` attribute with an `out` attribute. All other
+  // combinations of duplicated attributes are ill-formed.
+  if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) {
+    if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) ||
+        (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) {
+      D->dropAttr<HLSLParamModifierAttr>();
+      SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()};
+      return HLSLParamModifierAttr::Create(
+          Context, /*MergedSpelling=*/true, AdjustedRange,
+          HLSLParamModifierAttr::Keyword_inout);
+    }
+    Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL;
+    Diag(PA->getLocation(), diag::note_conflicting_attribute);
+    return nullptr;
+  }
+  return HLSLParamModifierAttr::Create(Context, AL);
+}
+
 static void handleMSInheritanceAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   if (!S.LangOpts.CPlusPlus) {
     S.Diag(AL.getLoc(), diag::err_attribute_not_supported_in_lang)
@@ -9456,6 +9486,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLResourceBinding:
     handleHLSLResourceBindingAttr(S, D, AL);
     break;
+  case ParsedAttr::AT_HLSLParamModifier:
+    handleHLSLParamModifierAttr(S, D, AL);
+    break;
 
   case ParsedAttr::AT_AbiTag:
     handleAbiTagAttr(S, D, AL);
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 560feafa1857cb3..cb384d6690d43dd 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -8661,6 +8661,13 @@ static void HandleLifetimeBoundAttr(TypeProcessingState &State,
   }
 }
 
+static void HandleHLSLParamModifierAttr(QualType &CurType,
+                                        const ParsedAttr &Attr, Sema &S) {
+  if (Attr.getSemanticSpelling() == HLSLParamModifierAttr::Keyword_inout ||
+      Attr.getSemanticSpelling() == HLSLParamModifierAttr::Keyword_out)
+    CurType = S.getASTContext().getLValueReferenceType(CurType);
+}
+
 static void processTypeAttrs(TypeProcessingState &state, QualType &type,
                              TypeAttrLocation TAL,
                              const ParsedAttributesView &attrs,
@@ -8837,6 +8844,12 @@ static void processTypeAttrs(TypeProcessingState &state, QualType &type,
       break;
     }
 
+    case ParsedAttr::AT_HLSLParamModifier: {
+      HandleHLSLParamModifierAttr(type, attr, state.getSema());
+      attr.setUsedAsTypeAttr();
+      break;
+    }
+
     MS_TYPE_ATTRS_CASELIST:
       if (!handleMSPointerTypeQualifierAttr(state, attr, type))
         attr.setUsedAsTypeAttr();
diff --git a/clang/test/SemaHLSL/parameter_modifiers.hlsl b/clang/test/SemaHLSL/parameter_modifiers.hlsl
new file mode 100644
index 000000000000000..d12c1ff95270910
--- /dev/null
+++ b/clang/test/SemaHLSL/parameter_modifiers.hlsl
@@ -0,0 +1,64 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library %s -verify
+void fn(in out float f); // #fn
+
+// expected-error@#fn2{{duplicate parameter modifier 'in'}}
+// expected-note@#fn2{{conflicting attribute is here}}
+void fn2(in in float f); // #fn2
+
+// expected-error@#fn3{{duplicate parameter modifier 'out'}}
+// expected-note@#fn3{{conflicting attribute is here}}
+void fn3(out out float f); // #fn3
+
+// expected-error@#fn4{{duplicate parameter modifier 'in'}}
+// expected-error@#fn4{{duplicate parameter modifier 'out'}}
+// expected-note@#fn4{{conflicting attribute is here}}
+// expected-note@#fn4{{conflicting attribute is here}}
+void fn4(inout in out float f); // #fn4
+
+// expected-error@#fn5{{duplicate parameter modifier 'in'}}
+// expected-note@#fn5{{conflicting attribute is here}}
+void fn5(inout in float f); // #fn5
+
+// expected-error@#fn6{{duplicate parameter modifier 'out'}}
+// expected-note@#fn6{{conflicting attribute is here}}
+void fn6(inout out float f); // #fn6
+
+// expected-error@#fn-def{{conflicting parameter qualifier 'out' on parameter 'f'}}
+// expected-note@#fn{{previously declared as 'inout' here}}
+void fn(out float f) { // #fn-def
+  f = 2;
+}
+
+// Overload resolution failure.
+void fn(in float f); // #fn-in
+
+void failOverloadResolution() {
+  float f = 1.0;
+  fn(f); // expected-error{{call to 'fn' is ambiguous}}
+  // expected-note@#fn-def{{candidate function}}
+  // expected-note@#fn-in{{candidate function}}
+}
+
+// No errors on these scenarios.
+
+// Alternating `inout` and `in out` spellings between declaration and
+// definitions is fine since they have the same semantic meaning.
+void fn7(inout float f);
+void fn7(in out float f) {}
+
+void fn8(in out float f);
+void fn8(inout float f) {}
+
+// These two declare two different functions (although calling them will be
+// ambiguous). This is equivalent to declaring a function that takes a
+// reference and a function that takes a value of the same type.
+void fn9(in float f);
+void fn9(out float f);
+
+// The `in` attribute is effectively optional. If no attribute is present it is
+// the same as `in`, so these declarations match the functions.
+void fn10(in float f);
+void fn10(float f) {}
+
+void fn11(float f);
+void fn11(in float f) {}
diff --git a/clang/test/SemaHLSL/parameter_modifiers_ast.hlsl b/clang/test/SemaHLSL/parameter_modifiers_ast.hlsl
new file mode 100644
index 000000000000000..faba95329649757
--- /dev/null
+++ b/clang/test/SemaHLSL/parameter_modifiers_ast.hlsl
@@ -0,0 +1,36 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library %s -ast-dump | FileCheck %s
+
+// CHECK: FunctionDecl {{.*}} fn 'void (float)'
+// CHECK-NEXT: ParmVarDecl {{.*}} f 'float'
+// CHECK-NOT: HLSLParamModifierAttr
+void fn(float f);
+
+// CHECK: FunctionDecl {{.*}}6 fn2 'void (float)'
+// CHECK-NEXT: ParmVarDecl {{.*}} f 'float'
+// CHECK-NEXT: HLSLParamModifierAttr {{.*}} in
+// CHECK-NOT: HLSLParamModifierAttr
+void fn2(in float f);
+
+// CHECK: FunctionDecl {{.*}} fn3 'void (float &)'
+// CHECK-NEXT: ParmVarDecl {{.*}} f 'float &'
+// CHECK-NEXT: HLSLParamModifierAttr {{.*}} out
+// CHECK-NOT: HLSLParamModifierAttr
+void fn3(out float f);
+
+// CHECK: FunctionDecl {{.*}} fn4 'void (float &)'
+// CHECK-NEXT: ParmVarDecl {{.*}} f 'float &'
+// CHECK-NEXT: HLSLParamModifierAttr {{.*}} inout
+// CHECK-NOT: HLSLParamModifierAttr
+void fn4(inout float f);
+
+// CHECK: FunctionDecl {{.*}} fn5 'void (float &)'
+// CHECK-NEXT: ParmVarDecl {{.*}} f 'float &'
+// CHECK-NEXT: HLSLParamModifierAttr {{.*}} inout MergedSpelling
+// CHECK-NOT: HLSLParamModifierAttr
+void fn5(out in float f);
+
+// CHECK: FunctionDecl {{.*}} fn6 'void (float &)'
+// CHECK-NEXT: ParmVarDecl {{.*}} f 'float &'
+// CHECK-NEXT: HLSLParamModifierAttr {{.*}} inout MergedSpelling
+// CHECK-NOT: HLSLParamModifierAttr
+void fn6(in out float f);

>From 39344561383ca65acd82f32ec95308fa1cc88f29 Mon Sep 17 00:00:00 2001
From: Chris Bieneman <chris.bieneman at me.com>
Date: Wed, 15 Nov 2023 09:16:35 -0600
Subject: [PATCH 2/2] Extend testing and handle template cases

Template instantiations for this attribute are a bit odd because HLSL
doesn't allow reference types to be spelled in source.

We handle this by not applying the type modification to
template-dependent types, and instead applying it during instantiation,
when the attribute is applied to the instantiated template type.

This update also includes additional testing to cover parameter
modifier usage in C++. I've also added an additional test case that
verifies errors are produced when trying to use the groupshared keyword
in C++ (an opportunistic fix included in my initial change).
---
 .../lib/Sema/SemaTemplateInstantiateDecl.cpp  | 14 ++++++
 clang/lib/Sema/SemaType.cpp                   |  4 ++
 clang/test/ParserHLSL/hlsl_groupshared.cpp    | 12 +++++
 .../ParserHLSL/hlsl_parameter_modifiers.cpp   | 50 +++++++++++++++++++
 clang/test/SemaHLSL/parameter_modifiers.hlsl  | 30 +++++++++++
 .../SemaHLSL/parameter_modifiers_ast.hlsl     | 29 +++++++++++
 6 files changed, 139 insertions(+)
 create mode 100644 clang/test/ParserHLSL/hlsl_groupshared.cpp
 create mode 100644 clang/test/ParserHLSL/hlsl_parameter_modifiers.cpp

diff --git a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
index 011356e08a04297..7ecb437cb2ef3c1 100644
--- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -663,6 +663,20 @@ static bool isRelevantAttr(Sema &S, const Decl *D, const Attr *A) {
   return true;
 }
 
+/// Clone an HLSL parameter modifier onto the instantiated parameter \p New.
+/// For `out`/`inout`, also apply the reference type that was deferred while
+/// the parameter type was dependent.
+static void instantiateDependentHLSLParamModifierAttr(
+    Sema &S, const MultiLevelTemplateArgumentList &TemplateArgs,
+    const HLSLParamModifierAttr *Attr, Decl *New) {
+  ParmVarDecl *P = cast<ParmVarDecl>(New);
+  P->addAttr(Attr->clone(S.getASTContext()));
+  // `in` parameters stay by-value. A non-dependent type was already made a
+  // reference when the template was parsed, so don't wrap it a second time.
+  if (Attr->isAnyOut() && !P->getType()->isReferenceType())
+    P->setType(S.getASTContext().getLValueReferenceType(P->getType()));
+}
+
 void Sema::InstantiateAttrsForDecl(
     const MultiLevelTemplateArgumentList &TemplateArgs, const Decl *Tmpl,
     Decl *New, LateInstantiatedAttrVec *LateAttrs,
@@ -784,6 +798,12 @@ void Sema::InstantiateAttrs(const MultiLevelTemplateArgumentList &TemplateArgs,
                                                *AMDGPUFlatWorkGroupSize, New);
     }
 
+    if (const auto *ParamAttr = dyn_cast<HLSLParamModifierAttr>(TmplAttr)) {
+      instantiateDependentHLSLParamModifierAttr(*this, TemplateArgs, ParamAttr,
+                                                New);
+      continue;
+    }
+
     // Existing DLL attribute on the instantiation takes precedence.
     if (TmplAttr->getKind() == attr::DLLExport ||
         TmplAttr->getKind() == attr::DLLImport) {
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index cb384d6690d43dd..56d133f20a29351 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -8663,6 +8663,10 @@ static void HandleLifetimeBoundAttr(TypeProcessingState &State,
 
 static void HandleHLSLParamModifierAttr(QualType &CurType,
                                         const ParsedAttr &Attr, Sema &S) {
+  // Don't apply this attribute to template dependent types. It is applied on
+  // substitution during template instantiation.
+  if (CurType->isDependentType())
+    return;
   if (Attr.getSemanticSpelling() == HLSLParamModifierAttr::Keyword_inout ||
       Attr.getSemanticSpelling() == HLSLParamModifierAttr::Keyword_out)
     CurType = S.getASTContext().getLValueReferenceType(CurType);
diff --git a/clang/test/ParserHLSL/hlsl_groupshared.cpp b/clang/test/ParserHLSL/hlsl_groupshared.cpp
new file mode 100644
index 000000000000000..eaa195a41147141
--- /dev/null
+++ b/clang/test/ParserHLSL/hlsl_groupshared.cpp
@@ -0,0 +1,12 @@
+// RUN: %clang_cc1 %s -verify
+extern groupshared float f; // expected-error{{unknown type name 'groupshared'}}
+
+extern float groupshared f2; // expected-error{{expected ';' after top level declarator}}
+
+namespace {
+float groupshared [[]] f3 = 12; // expected-error{{expected ';' after top level declarator}}
+}
+
+// expected-error@#fgc{{expected ';' after top level declarator}}
+// expected-error@#fgc{{a type specifier is required for all declarations}}
+float groupshared const f4 = 12; // #fgc
diff --git a/clang/test/ParserHLSL/hlsl_parameter_modifiers.cpp b/clang/test/ParserHLSL/hlsl_parameter_modifiers.cpp
new file mode 100644
index 000000000000000..cb553fbfa4d32fa
--- /dev/null
+++ b/clang/test/ParserHLSL/hlsl_parameter_modifiers.cpp
@@ -0,0 +1,50 @@
+// RUN: %clang_cc1 %s -verify
+
+// expected-error@#fn{{unknown type name 'in'}}
+// expected-error@#fn{{expected ')'}}
+// expected-note@#fn{{to match this '('}}
+void fn(in out float f); // #fn
+
+// expected-error@#fn2{{unknown type name 'in'}}
+// expected-error@#fn2{{expected ')'}}
+// expected-note@#fn2{{to match this '('}}
+void fn2(in in float f); // #fn2
+
+// expected-error@#fn3{{unknown type name 'out'}}
+// expected-error@#fn3{{expected ')'}}
+// expected-note@#fn3{{to match this '('}}
+void fn3(out out float f); // #fn3
+
+// expected-error@#fn4{{unknown type name 'inout'}}
+// expected-error@#fn4{{expected ')'}}
+// expected-note@#fn4{{to match this '('}}
+void fn4(inout in out float f); // #fn4
+
+// expected-error@#fn5{{unknown type name 'inout'}}
+// expected-error@#fn5{{expected ')'}}
+// expected-note@#fn5{{to match this '('}}
+void fn5(inout in float f); // #fn5
+
+// expected-error@#fn6{{unknown type name 'inout'}}
+// expected-error@#fn6{{expected ')'}}
+// expected-note@#fn6{{to match this '('}}
+void fn6(inout out float f); // #fn6
+
+void implicitFn(float f);
+
+// expected-error@#inFn{{unknown type name 'in'}}
+void inFn(in float f); // #inFn
+
+// expected-error@#inoutFn{{unknown type name 'inout'}}
+void inoutFn(inout float f); // #inoutFn
+
+// expected-error@#outFn{{unknown type name 'out'}}
+void outFn(out float f); // #outFn
+
+// expected-error@#fn7{{unknown type name 'inout'}}
+// expected-error@#fn7{{declaration of 'T' shadows template parameter}}
+// expected-error@#fn7{{expected ')'}}
+// expected-note@#fn7{{to match this '('}}
+template <typename T> // expected-note{{template parameter is declared here}}
+void fn7(inout T f); // #fn7
+
diff --git a/clang/test/SemaHLSL/parameter_modifiers.hlsl b/clang/test/SemaHLSL/parameter_modifiers.hlsl
index d12c1ff95270910..dd608115aa1d9c9 100644
--- a/clang/test/SemaHLSL/parameter_modifiers.hlsl
+++ b/clang/test/SemaHLSL/parameter_modifiers.hlsl
@@ -39,6 +39,28 @@ void failOverloadResolution() {
   // expected-note@#fn-in{{candidate function}}
 }
 
+void implicitFn(float f);
+void inFn(in float f);
+void inoutFn(inout float f); // #inoutFn
+void outFn(out float f); // #outFn
+
+void callFns() {
+  // Call with literal arguments.
+  implicitFn(1); // Ok.
+  inFn(1); // Ok.
+  inoutFn(1); // expected-error{{no matching function for call to 'inoutFn'}}
+  // expected-note@#inoutFn{{candidate function not viable: no known conversion from 'int' to 'float &' for 1st argument}}
+  outFn(1); // expected-error{{no matching function for call to 'outFn}}
+  // expected-note@#outFn{{candidate function not viable: no known conversion from 'int' to 'float &' for 1st argument}}
+  
+  // Call with variables.
+  float f;
+  implicitFn(f); // Ok.
+  inFn(f); // Ok.
+  inoutFn(f); // Ok.
+  outFn(f); // Ok.
+}
+
 // No errors on these scenarios.
 
 // Alternating `inout` and `in out` spellings between declaration and
@@ -62,3 +84,11 @@ void fn10(float f) {}
 
 void fn11(float f);
 void fn11(in float f) {}
+
+template <typename T>
+void fn12(inout T f);
+
+void fn13() {
+  float f;
+  fn12<float>(f);
+}
diff --git a/clang/test/SemaHLSL/parameter_modifiers_ast.hlsl b/clang/test/SemaHLSL/parameter_modifiers_ast.hlsl
index faba95329649757..50b162bdfc26cc0 100644
--- a/clang/test/SemaHLSL/parameter_modifiers_ast.hlsl
+++ b/clang/test/SemaHLSL/parameter_modifiers_ast.hlsl
@@ -34,3 +34,32 @@ void fn5(out in float f);
 // CHECK-NEXT: HLSLParamModifierAttr {{.*}} inout MergedSpelling
 // CHECK-NOT: HLSLParamModifierAttr
 void fn6(in out float f);
+
+// CHECK-NEXT: FunctionTemplateDecl [[Template:0x[0-9a-fA-F]+]] {{.*}} fn7
+// CHECK-NEXT: TemplateTypeParmDecl {{.*}} referenced typename depth 0 index 0 T
+// CHECK-NEXT: FunctionDecl {{.*}} fn7 'void (T)'
+// CHECK-NEXT: ParmVarDecl {{.*}} f 'T'
+// CHECK-NEXT: HLSLParamModifierAttr {{.*}} inout
+// CHECK-NEXT: FunctionDecl [[Instantiation:0x[0-9a-fA-F]+]] {{.*}} used fn7 'void (float &)' implicit_instantiation
+// CHECK-NEXT: TemplateArgument type 'float'
+// CHECK-NEXT:  BuiltinType {{.*}} 'float'
+// CHECK-NEXT: ParmVarDecl {{.*}} f 'float &'
+// CHECK-NEXT: HLSLParamModifierAttr {{.*}} inout
+
+template <typename T>
+void fn7(inout T f);
+
+// CHECK: FunctionDecl {{.*}} fn8 'void ()'
+// CHECK-NEXT: CompoundStmt
+// CHECK-NEXT: DeclStmt
+// CHECK-NEXT: VarDecl {{.*}} used f 'float'
+// CHECK-NEXT: CallExpr {{.*}} 'void'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'void (*)(float &)' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'void (float &)' lvalue
+// CHECK-SAME: Function [[Instantiation]] 'fn7' 'void (float &)'
+// CHECK-SAME: (FunctionTemplate [[Template]] 'fn7')
+// CHECK-NEXT: DeclRefExpr {{.*}} 'float' lvalue Var {{.*}} 'f' 'float'
+void fn8() {
+  float f;
+  fn7<float>(f);
+}



More information about the cfe-commits mailing list