[clang] Enable matrices in HLSL (PR #111415)

via cfe-commits cfe-commits at lists.llvm.org
Mon Oct 7 11:09:12 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Greg Roth (pow2clk)

<details>
<summary>Changes</summary>

HLSL needs matrix support. This allows matrices when the target language is HLSL and defines a matrix template alias that allows the short forms of matrix types to be defined as typedefs in a default header. It also changes how matrices are printed under the HLSL printing policy. This required a tweak to how existing matrices are printed in diagnostics, which makes them more consistent with other types.

These matrix types will function exactly as the clang matrix extension dictates. Alterations to that behavior both specific to HLSL and also potentially expanding on the matrix extension will follow.

fixes #<!-- -->109839

---

Patch is 222.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/111415.diff


25 Files Affected:

- (modified) clang/include/clang/Sema/HLSLExternalSemaSource.h (+1) 
- (modified) clang/lib/AST/ASTContext.cpp (+1-1) 
- (modified) clang/lib/AST/TypePrinter.cpp (+28-12) 
- (modified) clang/lib/Headers/hlsl/hlsl_basic_types.h (+232) 
- (modified) clang/lib/Sema/HLSLExternalSemaSource.cpp (+73) 
- (modified) clang/lib/Sema/SemaType.cpp (+1-1) 
- (added) clang/test/AST/HLSL/matrix-alias.hlsl (+49) 
- (modified) clang/test/AST/HLSL/vector-alias.hlsl (+1-1) 
- (modified) clang/test/CodeGenCXX/matrix-type-operators.cpp (+3-5) 
- (added) clang/test/CodeGenHLSL/Types/BuiltinMatrix/matrix-cast-template.hlsl (+351) 
- (added) clang/test/CodeGenHLSL/Types/BuiltinMatrix/matrix-cast.hlsl (+162) 
- (added) clang/test/CodeGenHLSL/Types/BuiltinMatrix/matrix-transpose-template.hlsl (+82) 
- (added) clang/test/CodeGenHLSL/Types/BuiltinMatrix/matrix-transpose.hlsl (+95) 
- (added) clang/test/CodeGenHLSL/Types/BuiltinMatrix/matrix-type-operators-template.hlsl (+447) 
- (added) clang/test/CodeGenHLSL/Types/BuiltinMatrix/matrix-type-operators.hlsl (+1515) 
- (added) clang/test/CodeGenHLSL/Types/BuiltinMatrix/matrix-type.hlsl (+217) 
- (modified) clang/test/CodeGenHLSL/basic_types.hlsl (+18) 
- (added) clang/test/CodeGenHLSL/matrix-types.hlsl (+348) 
- (modified) clang/test/Sema/matrix-type-operators.c (+8-8) 
- (added) clang/test/SemaHLSL/Types/BuiltinMatrix/matrix-cast.hlsl (+138) 
- (added) clang/test/SemaHLSL/Types/BuiltinMatrix/matrix-index-operator-type.hlsl (+27) 
- (added) clang/test/SemaHLSL/Types/BuiltinMatrix/matrix-transpose.hlsl (+56) 
- (added) clang/test/SemaHLSL/Types/BuiltinMatrix/matrix-type-operators.hlsl (+307) 
- (added) clang/test/SemaHLSL/Types/BuiltinMatrix/matrix-type.hlsl (+48) 
- (modified) clang/test/SemaTemplate/matrix-type.cpp (+1-1) 


``````````diff
diff --git a/clang/include/clang/Sema/HLSLExternalSemaSource.h b/clang/include/clang/Sema/HLSLExternalSemaSource.h
index 3c7495e66055dc..6f4b72045a9464 100644
--- a/clang/include/clang/Sema/HLSLExternalSemaSource.h
+++ b/clang/include/clang/Sema/HLSLExternalSemaSource.h
@@ -28,6 +28,7 @@ class HLSLExternalSemaSource : public ExternalSemaSource {
   llvm::DenseMap<CXXRecordDecl *, CompletionFunction> Completions;
 
   void defineHLSLVectorAlias();
+  void defineHLSLMatrixAlias();
   void defineTrivialHLSLTypes();
   void defineHLSLTypesWithForwardDeclarations();
 
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index a81429ad6a2380..ed10c210ed170f 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -1381,7 +1381,7 @@ void ASTContext::InitBuiltinTypes(const TargetInfo &Target,
   if (LangOpts.OpenACC && !LangOpts.OpenMP) {
     InitBuiltinType(ArraySectionTy, BuiltinType::ArraySection);
   }
-  if (LangOpts.MatrixTypes)
+  if (LangOpts.MatrixTypes || LangOpts.HLSL)
     InitBuiltinType(IncompleteMatrixIdxTy, BuiltinType::IncompleteMatrixIdx);
 
   // Builtin types for 'id', 'Class', and 'SEL'.
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index ca75bb97c158e1..142717201557f3 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -852,34 +852,50 @@ void TypePrinter::printExtVectorAfter(const ExtVectorType *T, raw_ostream &OS) {
 
 void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
                                             raw_ostream &OS) {
+  if (Policy.UseHLSLTypes)
+    OS << "matrix<";
   printBefore(T->getElementType(), OS);
-  OS << " __attribute__((matrix_type(";
-  OS << T->getNumRows() << ", " << T->getNumColumns();
-  OS << ")))";
 }
 
 void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
                                            raw_ostream &OS) {
   printAfter(T->getElementType(), OS);
+  if (Policy.UseHLSLTypes) {
+    OS << ", ";
+    OS << T->getNumRows() << ", " << T->getNumColumns();
+    OS << ">";
+  } else {
+    OS << " __attribute__((matrix_type(";
+    OS << T->getNumRows() << ", " << T->getNumColumns();
+    OS << ")))";
+  }
 }
 
 void TypePrinter::printDependentSizedMatrixBefore(
     const DependentSizedMatrixType *T, raw_ostream &OS) {
+  if (Policy.UseHLSLTypes)
+    OS << "matrix<";
   printBefore(T->getElementType(), OS);
-  OS << " __attribute__((matrix_type(";
-  if (T->getRowExpr()) {
-    T->getRowExpr()->printPretty(OS, nullptr, Policy);
-  }
-  OS << ", ";
-  if (T->getColumnExpr()) {
-    T->getColumnExpr()->printPretty(OS, nullptr, Policy);
-  }
-  OS << ")))";
 }
 
 void TypePrinter::printDependentSizedMatrixAfter(
     const DependentSizedMatrixType *T, raw_ostream &OS) {
   printAfter(T->getElementType(), OS);
+  if (Policy.UseHLSLTypes)
+    OS << ", ";
+  else
+    OS << " __attribute__((matrix_type(";
+
+  if (Expr *E = T->getRowExpr())
+    E->printPretty(OS, nullptr, Policy);
+  OS << ", ";
+  if (Expr *E = T->getColumnExpr())
+    E->printPretty(OS, nullptr, Policy);
+
+  if (Policy.UseHLSLTypes)
+    OS << ">";
+  else
+    OS << ")))";
 }
 
 void
diff --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h b/clang/lib/Headers/hlsl/hlsl_basic_types.h
index eff94e0d7f9500..b6eeffa2f5e362 100644
--- a/clang/lib/Headers/hlsl/hlsl_basic_types.h
+++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h
@@ -115,6 +115,238 @@ typedef vector<float64_t, 2> float64_t2;
 typedef vector<float64_t, 3> float64_t3;
 typedef vector<float64_t, 4> float64_t4;
 
+#ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<int16_t, 1, 1> int16_t1x1;
+typedef matrix<int16_t, 1, 2> int16_t1x2;
+typedef matrix<int16_t, 1, 3> int16_t1x3;
+typedef matrix<int16_t, 1, 4> int16_t1x4;
+typedef matrix<int16_t, 2, 1> int16_t2x1;
+typedef matrix<int16_t, 2, 2> int16_t2x2;
+typedef matrix<int16_t, 2, 3> int16_t2x3;
+typedef matrix<int16_t, 2, 4> int16_t2x4;
+typedef matrix<int16_t, 3, 1> int16_t3x1;
+typedef matrix<int16_t, 3, 2> int16_t3x2;
+typedef matrix<int16_t, 3, 3> int16_t3x3;
+typedef matrix<int16_t, 3, 4> int16_t3x4;
+typedef matrix<int16_t, 4, 1> int16_t4x1;
+typedef matrix<int16_t, 4, 2> int16_t4x2;
+typedef matrix<int16_t, 4, 3> int16_t4x3;
+typedef matrix<int16_t, 4, 4> int16_t4x4;
+typedef matrix<uint16_t, 1, 1> uint16_t1x1;
+typedef matrix<uint16_t, 1, 2> uint16_t1x2;
+typedef matrix<uint16_t, 1, 3> uint16_t1x3;
+typedef matrix<uint16_t, 1, 4> uint16_t1x4;
+typedef matrix<uint16_t, 2, 1> uint16_t2x1;
+typedef matrix<uint16_t, 2, 2> uint16_t2x2;
+typedef matrix<uint16_t, 2, 3> uint16_t2x3;
+typedef matrix<uint16_t, 2, 4> uint16_t2x4;
+typedef matrix<uint16_t, 3, 1> uint16_t3x1;
+typedef matrix<uint16_t, 3, 2> uint16_t3x2;
+typedef matrix<uint16_t, 3, 3> uint16_t3x3;
+typedef matrix<uint16_t, 3, 4> uint16_t3x4;
+typedef matrix<uint16_t, 4, 1> uint16_t4x1;
+typedef matrix<uint16_t, 4, 2> uint16_t4x2;
+typedef matrix<uint16_t, 4, 3> uint16_t4x3;
+typedef matrix<uint16_t, 4, 4> uint16_t4x4;
+#endif
+typedef matrix<int, 1, 1> int1x1;
+typedef matrix<int, 1, 2> int1x2;
+typedef matrix<int, 1, 3> int1x3;
+typedef matrix<int, 1, 4> int1x4;
+typedef matrix<int, 2, 1> int2x1;
+typedef matrix<int, 2, 2> int2x2;
+typedef matrix<int, 2, 3> int2x3;
+typedef matrix<int, 2, 4> int2x4;
+typedef matrix<int, 3, 1> int3x1;
+typedef matrix<int, 3, 2> int3x2;
+typedef matrix<int, 3, 3> int3x3;
+typedef matrix<int, 3, 4> int3x4;
+typedef matrix<int, 4, 1> int4x1;
+typedef matrix<int, 4, 2> int4x2;
+typedef matrix<int, 4, 3> int4x3;
+typedef matrix<int, 4, 4> int4x4;
+typedef matrix<uint, 1, 1> uint1x1;
+typedef matrix<uint, 1, 2> uint1x2;
+typedef matrix<uint, 1, 3> uint1x3;
+typedef matrix<uint, 1, 4> uint1x4;
+typedef matrix<uint, 2, 1> uint2x1;
+typedef matrix<uint, 2, 2> uint2x2;
+typedef matrix<uint, 2, 3> uint2x3;
+typedef matrix<uint, 2, 4> uint2x4;
+typedef matrix<uint, 3, 1> uint3x1;
+typedef matrix<uint, 3, 2> uint3x2;
+typedef matrix<uint, 3, 3> uint3x3;
+typedef matrix<uint, 3, 4> uint3x4;
+typedef matrix<uint, 4, 1> uint4x1;
+typedef matrix<uint, 4, 2> uint4x2;
+typedef matrix<uint, 4, 3> uint4x3;
+typedef matrix<uint, 4, 4> uint4x4;
+typedef matrix<int32_t, 1, 1> int32_t1x1;
+typedef matrix<int32_t, 1, 2> int32_t1x2;
+typedef matrix<int32_t, 1, 3> int32_t1x3;
+typedef matrix<int32_t, 1, 4> int32_t1x4;
+typedef matrix<int32_t, 2, 1> int32_t2x1;
+typedef matrix<int32_t, 2, 2> int32_t2x2;
+typedef matrix<int32_t, 2, 3> int32_t2x3;
+typedef matrix<int32_t, 2, 4> int32_t2x4;
+typedef matrix<int32_t, 3, 1> int32_t3x1;
+typedef matrix<int32_t, 3, 2> int32_t3x2;
+typedef matrix<int32_t, 3, 3> int32_t3x3;
+typedef matrix<int32_t, 3, 4> int32_t3x4;
+typedef matrix<int32_t, 4, 1> int32_t4x1;
+typedef matrix<int32_t, 4, 2> int32_t4x2;
+typedef matrix<int32_t, 4, 3> int32_t4x3;
+typedef matrix<int32_t, 4, 4> int32_t4x4;
+typedef matrix<uint32_t, 1, 1> uint32_t1x1;
+typedef matrix<uint32_t, 1, 2> uint32_t1x2;
+typedef matrix<uint32_t, 1, 3> uint32_t1x3;
+typedef matrix<uint32_t, 1, 4> uint32_t1x4;
+typedef matrix<uint32_t, 2, 1> uint32_t2x1;
+typedef matrix<uint32_t, 2, 2> uint32_t2x2;
+typedef matrix<uint32_t, 2, 3> uint32_t2x3;
+typedef matrix<uint32_t, 2, 4> uint32_t2x4;
+typedef matrix<uint32_t, 3, 1> uint32_t3x1;
+typedef matrix<uint32_t, 3, 2> uint32_t3x2;
+typedef matrix<uint32_t, 3, 3> uint32_t3x3;
+typedef matrix<uint32_t, 3, 4> uint32_t3x4;
+typedef matrix<uint32_t, 4, 1> uint32_t4x1;
+typedef matrix<uint32_t, 4, 2> uint32_t4x2;
+typedef matrix<uint32_t, 4, 3> uint32_t4x3;
+typedef matrix<uint32_t, 4, 4> uint32_t4x4;
+typedef matrix<int64_t, 1, 1> int64_t1x1;
+typedef matrix<int64_t, 1, 2> int64_t1x2;
+typedef matrix<int64_t, 1, 3> int64_t1x3;
+typedef matrix<int64_t, 1, 4> int64_t1x4;
+typedef matrix<int64_t, 2, 1> int64_t2x1;
+typedef matrix<int64_t, 2, 2> int64_t2x2;
+typedef matrix<int64_t, 2, 3> int64_t2x3;
+typedef matrix<int64_t, 2, 4> int64_t2x4;
+typedef matrix<int64_t, 3, 1> int64_t3x1;
+typedef matrix<int64_t, 3, 2> int64_t3x2;
+typedef matrix<int64_t, 3, 3> int64_t3x3;
+typedef matrix<int64_t, 3, 4> int64_t3x4;
+typedef matrix<int64_t, 4, 1> int64_t4x1;
+typedef matrix<int64_t, 4, 2> int64_t4x2;
+typedef matrix<int64_t, 4, 3> int64_t4x3;
+typedef matrix<int64_t, 4, 4> int64_t4x4;
+typedef matrix<uint64_t, 1, 1> uint64_t1x1;
+typedef matrix<uint64_t, 1, 2> uint64_t1x2;
+typedef matrix<uint64_t, 1, 3> uint64_t1x3;
+typedef matrix<uint64_t, 1, 4> uint64_t1x4;
+typedef matrix<uint64_t, 2, 1> uint64_t2x1;
+typedef matrix<uint64_t, 2, 2> uint64_t2x2;
+typedef matrix<uint64_t, 2, 3> uint64_t2x3;
+typedef matrix<uint64_t, 2, 4> uint64_t2x4;
+typedef matrix<uint64_t, 3, 1> uint64_t3x1;
+typedef matrix<uint64_t, 3, 2> uint64_t3x2;
+typedef matrix<uint64_t, 3, 3> uint64_t3x3;
+typedef matrix<uint64_t, 3, 4> uint64_t3x4;
+typedef matrix<uint64_t, 4, 1> uint64_t4x1;
+typedef matrix<uint64_t, 4, 2> uint64_t4x2;
+typedef matrix<uint64_t, 4, 3> uint64_t4x3;
+typedef matrix<uint64_t, 4, 4> uint64_t4x4;
+
+typedef matrix<half, 1, 1> half1x1;
+typedef matrix<half, 1, 2> half1x2;
+typedef matrix<half, 1, 3> half1x3;
+typedef matrix<half, 1, 4> half1x4;
+typedef matrix<half, 2, 1> half2x1;
+typedef matrix<half, 2, 2> half2x2;
+typedef matrix<half, 2, 3> half2x3;
+typedef matrix<half, 2, 4> half2x4;
+typedef matrix<half, 3, 1> half3x1;
+typedef matrix<half, 3, 2> half3x2;
+typedef matrix<half, 3, 3> half3x3;
+typedef matrix<half, 3, 4> half3x4;
+typedef matrix<half, 4, 1> half4x1;
+typedef matrix<half, 4, 2> half4x2;
+typedef matrix<half, 4, 3> half4x3;
+typedef matrix<half, 4, 4> half4x4;
+typedef matrix<float, 1, 1> float1x1;
+typedef matrix<float, 1, 2> float1x2;
+typedef matrix<float, 1, 3> float1x3;
+typedef matrix<float, 1, 4> float1x4;
+typedef matrix<float, 2, 1> float2x1;
+typedef matrix<float, 2, 2> float2x2;
+typedef matrix<float, 2, 3> float2x3;
+typedef matrix<float, 2, 4> float2x4;
+typedef matrix<float, 3, 1> float3x1;
+typedef matrix<float, 3, 2> float3x2;
+typedef matrix<float, 3, 3> float3x3;
+typedef matrix<float, 3, 4> float3x4;
+typedef matrix<float, 4, 1> float4x1;
+typedef matrix<float, 4, 2> float4x2;
+typedef matrix<float, 4, 3> float4x3;
+typedef matrix<float, 4, 4> float4x4;
+typedef matrix<double, 1, 1> double1x1;
+typedef matrix<double, 1, 2> double1x2;
+typedef matrix<double, 1, 3> double1x3;
+typedef matrix<double, 1, 4> double1x4;
+typedef matrix<double, 2, 1> double2x1;
+typedef matrix<double, 2, 2> double2x2;
+typedef matrix<double, 2, 3> double2x3;
+typedef matrix<double, 2, 4> double2x4;
+typedef matrix<double, 3, 1> double3x1;
+typedef matrix<double, 3, 2> double3x2;
+typedef matrix<double, 3, 3> double3x3;
+typedef matrix<double, 3, 4> double3x4;
+typedef matrix<double, 4, 1> double4x1;
+typedef matrix<double, 4, 2> double4x2;
+typedef matrix<double, 4, 3> double4x3;
+typedef matrix<double, 4, 4> double4x4;
+
+#ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<float16_t, 1, 1> float16_t1x1;
+typedef matrix<float16_t, 1, 2> float16_t1x2;
+typedef matrix<float16_t, 1, 3> float16_t1x3;
+typedef matrix<float16_t, 1, 4> float16_t1x4;
+typedef matrix<float16_t, 2, 1> float16_t2x1;
+typedef matrix<float16_t, 2, 2> float16_t2x2;
+typedef matrix<float16_t, 2, 3> float16_t2x3;
+typedef matrix<float16_t, 2, 4> float16_t2x4;
+typedef matrix<float16_t, 3, 1> float16_t3x1;
+typedef matrix<float16_t, 3, 2> float16_t3x2;
+typedef matrix<float16_t, 3, 3> float16_t3x3;
+typedef matrix<float16_t, 3, 4> float16_t3x4;
+typedef matrix<float16_t, 4, 1> float16_t4x1;
+typedef matrix<float16_t, 4, 2> float16_t4x2;
+typedef matrix<float16_t, 4, 3> float16_t4x3;
+typedef matrix<float16_t, 4, 4> float16_t4x4;
+#endif
+
+typedef matrix<float32_t, 1, 1> float32_t1x1;
+typedef matrix<float32_t, 1, 2> float32_t1x2;
+typedef matrix<float32_t, 1, 3> float32_t1x3;
+typedef matrix<float32_t, 1, 4> float32_t1x4;
+typedef matrix<float32_t, 2, 1> float32_t2x1;
+typedef matrix<float32_t, 2, 2> float32_t2x2;
+typedef matrix<float32_t, 2, 3> float32_t2x3;
+typedef matrix<float32_t, 2, 4> float32_t2x4;
+typedef matrix<float32_t, 3, 1> float32_t3x1;
+typedef matrix<float32_t, 3, 2> float32_t3x2;
+typedef matrix<float32_t, 3, 3> float32_t3x3;
+typedef matrix<float32_t, 3, 4> float32_t3x4;
+typedef matrix<float32_t, 4, 1> float32_t4x1;
+typedef matrix<float32_t, 4, 2> float32_t4x2;
+typedef matrix<float32_t, 4, 3> float32_t4x3;
+typedef matrix<float32_t, 4, 4> float32_t4x4;
+typedef matrix<float64_t, 1, 1> float64_t1x1;
+typedef matrix<float64_t, 1, 2> float64_t1x2;
+typedef matrix<float64_t, 1, 3> float64_t1x3;
+typedef matrix<float64_t, 1, 4> float64_t1x4;
+typedef matrix<float64_t, 2, 1> float64_t2x1;
+typedef matrix<float64_t, 2, 2> float64_t2x2;
+typedef matrix<float64_t, 2, 3> float64_t2x3;
+typedef matrix<float64_t, 2, 4> float64_t2x4;
+typedef matrix<float64_t, 3, 1> float64_t3x1;
+typedef matrix<float64_t, 3, 2> float64_t3x2;
+typedef matrix<float64_t, 3, 3> float64_t3x3;
+typedef matrix<float64_t, 3, 4> float64_t3x4;
+typedef matrix<float64_t, 4, 1> float64_t4x1;
+typedef matrix<float64_t, 4, 2> float64_t4x2;
+typedef matrix<float64_t, 4, 3> float64_t4x3;
+typedef matrix<float64_t, 4, 4> float64_t4x4;
+
 } // namespace hlsl
 
 #endif //_HLSL_HLSL_BASIC_TYPES_H_
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index 2913d16fca4823..d1a53d2ad88864 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -472,8 +472,81 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() {
   HLSLNamespace->addDecl(Template);
 }
 
+/// Define the implicit hlsl::matrix<element, rows_count, cols_count> alias.
+void HLSLExternalSemaSource::defineHLSLMatrixAlias() {
+  ASTContext &AST = SemaPtr->getASTContext();
+  llvm::SmallVector<NamedDecl *> TemplateParams;
+
+  auto *TypeParam = TemplateTypeParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
+      &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
+  TypeParam->setDefaultArgument(
+      AST, SemaPtr->getTrivialTemplateArgumentLoc(
+               TemplateArgument(AST.FloatTy), QualType(), SourceLocation()));
+
+  TemplateParams.emplace_back(TypeParam);
+
+  // these should be 64 bit to be consistent with other clang matrices.
+  auto *RowsParam = NonTypeTemplateParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
+      &AST.Idents.get("rows_count", tok::TokenKind::identifier), AST.IntTy,
+      false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+  llvm::APInt RVal(AST.getIntWidth(AST.IntTy), 4);
+  TemplateArgument RDefault(AST, llvm::APSInt(std::move(RVal)), AST.IntTy,
+                           /*IsDefaulted=*/true);
+  RowsParam->setDefaultArgument(
+      AST, SemaPtr->getTrivialTemplateArgumentLoc(RDefault, AST.IntTy,
+                                                  SourceLocation(), RowsParam));
+  TemplateParams.emplace_back(RowsParam);
+
+  auto *ColsParam = NonTypeTemplateParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 2,
+      &AST.Idents.get("cols_count", tok::TokenKind::identifier), AST.IntTy,
+      false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+  llvm::APInt CVal(AST.getIntWidth(AST.IntTy), 4);
+  TemplateArgument CDefault(AST, llvm::APSInt(std::move(CVal)), AST.IntTy,
+                           /*IsDefaulted=*/true);
+  ColsParam->setDefaultArgument(
+      AST, SemaPtr->getTrivialTemplateArgumentLoc(CDefault, AST.IntTy,
+                                                  SourceLocation(), ColsParam));
+  TemplateParams.emplace_back(ColsParam);
+
+  auto *ParamList =
+      TemplateParameterList::Create(AST, SourceLocation(), SourceLocation(),
+                                    TemplateParams, SourceLocation(), nullptr);
+
+  IdentifierInfo &II = AST.Idents.get("matrix", tok::TokenKind::identifier);
+
+  QualType AliasType = AST.getDependentSizedMatrixType(
+      AST.getTemplateTypeParmType(0, 0, false, TypeParam),
+      DeclRefExpr::Create(
+          AST, NestedNameSpecifierLoc(), SourceLocation(), RowsParam, false,
+          DeclarationNameInfo(RowsParam->getDeclName(), SourceLocation()),
+          AST.IntTy, VK_LValue),
+      DeclRefExpr::Create(
+          AST, NestedNameSpecifierLoc(), SourceLocation(), ColsParam, false,
+          DeclarationNameInfo(ColsParam->getDeclName(), SourceLocation()),
+          AST.IntTy, VK_LValue),
+      SourceLocation());
+
+  auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
+                                       SourceLocation(), &II,
+                                       AST.getTrivialTypeSourceInfo(AliasType));
+  Record->setImplicit(true);
+
+  auto *Template =
+      TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
+                                    Record->getIdentifier(), ParamList, Record);
+
+  Record->setDescribedAliasTemplate(Template);
+  Template->setImplicit(true);
+  Template->setLexicalDeclContext(Record->getDeclContext());
+  HLSLNamespace->addDecl(Template);
+}
+
 void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
   defineHLSLVectorAlias();
+  defineHLSLMatrixAlias();
 }
 
 /// Set up common members and attributes for buffer types
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index c44fc9c4194ca4..9213b4d95a70d9 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2447,7 +2447,7 @@ QualType Sema::BuildExtVectorType(QualType T, Expr *ArraySize,
 
 QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
                                SourceLocation AttrLoc) {
-  assert(Context.getLangOpts().MatrixTypes &&
+  assert((getLangOpts().MatrixTypes || getLangOpts().HLSL) &&
          "Should never build a matrix type when it is disabled");
 
   // Check element type, if it is not dependent.
diff --git a/clang/test/AST/HLSL/matrix-alias.hlsl b/clang/test/AST/HLSL/matrix-alias.hlsl
new file mode 100644
index 00000000000000..afac2cfed7604b
--- /dev/null
+++ b/clang/test/AST/HLSL/matrix-alias.hlsl
@@ -0,0 +1,49 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+
+// Test that matrix aliases are set up properly for HLSL
+
+// CHECK: NamespaceDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit hlsl
+// CHECK-NEXT: TypeAliasTemplateDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector
+// CHECK-NEXT: TemplateTypeParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> class depth 0 index 0 element
+// CHECK-NEXT: TemplateArgument type 'float'
+// CHECK-NEXT: BuiltinType 0x{{[0-9a-fA-F]+}} 'float'
+// CHECK-NEXT: NonTypeTemplateParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> 'int' depth 0 index 1 element_count
+// CHECK-NEXT: TemplateArgument expr
+// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' 4
+// CHECK-NEXT: TypeAliasDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector 'vector<element, element_count>'
+// CHECK-NEXT: DependentSizedExtVectorType 0x{{[0-9a-fA-F]+}} 'vector<element, element_count>' dependent <invalid sloc>
+// CHECK-NEXT: TemplateTypeParmType 0x{{[0-9a-fA-F]+}} 'element' dependent depth 0 index 0
+// CHECK-NEXT: TemplateTypeParm 0x{{[0-9a-fA-F]+}} 'element'
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' lvalue
+// CHECK-SAME: NonTypeTemplateParm 0x{{[0-9a-fA-F]+}} 'element_count' 'int'
+
+// Make sure we got a using directive at the end.
+// CHECK: UsingDirectiveDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> Namespace 0x{{[0-9a-fA-F]+}} 'hlsl'
+
+[numthreads(1,1,1)]
+int entry() {
+  // Verify that the alias is generated inside the hlsl namespace.
+  hlsl::matrix<float, 2, 2> Mat2x2;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:26:3, col:35>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:29> col:29 Mat2x2 'hlsl::matrix<float, 2, 2>'
+
+  // Verify that you don't need to specify the namespace.
+  matrix<int, 2, 2> Vec2x2a;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:32:3, col:28>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:21> col:21 Vec2x2a 'matrix<int, 2, 2>'
+
+  // Build a bigger matrix.
+  matrix<double, 4, 4> Mat4x4;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:38:3, col:30>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:24> col:24 Mat4x4 'matrix<double, 4, 4>'
+
+  // Verify that the implicit arguments generate the correct type.
+  matrix<> ImpM...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/111415


More information about the cfe-commits mailing list