[clang] 6ac40aa - [HLSL] Add support for the HLSL matrix type (#159446)

via cfe-commits cfe-commits at lists.llvm.org
Tue Sep 23 10:02:22 PDT 2025


Author: Farzon Lotfi
Date: 2025-09-23T13:02:17-04:00
New Revision: 6ac40aaaf69b710b10873b82593d4315a0838726

URL: https://github.com/llvm/llvm-project/commit/6ac40aaaf69b710b10873b82593d4315a0838726
DIFF: https://github.com/llvm/llvm-project/commit/6ac40aaaf69b710b10873b82593d4315a0838726.diff

LOG: [HLSL] Add support for the HLSL matrix type (#159446)

fixes #109839

This change is small: it creates a `matrix` alias template that lets
HLSL use the existing clang `matrix_type` infrastructure.

The only additional change is to add explicit typedefs (e.g. `float4x4`)
for the typed matrices with 1-4 rows and 1-4 columns (inclusive) available in HLSL.

Testing is therefore limited to exercising the alias, Sema errors, and
basic codegen. Future work will add things like constructors and accessors.

The main differences in this attempt are the type printer changes and less
emphasis on tests that overlap with existing `matrix_type` testing, such as
cast behavior.

Added: 
    clang/test/AST/HLSL/matrix-alias.hlsl
    clang/test/CodeGenHLSL/builtins/transpose-builtin.hlsl
    clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl
    clang/test/SemaHLSL/BuiltIns/matrix-errors.hlsl

Modified: 
    clang/include/clang/Driver/Options.td
    clang/include/clang/Sema/HLSLExternalSemaSource.h
    clang/lib/AST/TypePrinter.cpp
    clang/lib/Headers/hlsl/hlsl_basic_types.h
    clang/lib/Sema/HLSLExternalSemaSource.cpp
    clang/lib/Sema/SemaHLSL.cpp

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index 16e1c396fedbe..77f19a240a7f9 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -4587,7 +4587,7 @@ defm ptrauth_block_descriptor_pointers : OptInCC1FFlag<"ptrauth-block-descriptor
 def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
     Visibility<[ClangOption, CC1Option]>,
     HelpText<"Enable matrix data type and related builtin functions">,
-    MarshallingInfoFlag<LangOpts<"MatrixTypes">>;
+    MarshallingInfoFlag<LangOpts<"MatrixTypes">, hlsl.KeyPath>;
 
 defm raw_string_literals : BoolFOption<"raw-string-literals",
     LangOpts<"RawStringLiterals">, Default<std#".hasRawStringLiterals()">,

diff  --git a/clang/include/clang/Sema/HLSLExternalSemaSource.h b/clang/include/clang/Sema/HLSLExternalSemaSource.h
index d93fb8c8eef6b..049fc7b8fe3f2 100644
--- a/clang/include/clang/Sema/HLSLExternalSemaSource.h
+++ b/clang/include/clang/Sema/HLSLExternalSemaSource.h
@@ -44,6 +44,7 @@ class HLSLExternalSemaSource : public ExternalSemaSource {
 private:
   void defineTrivialHLSLTypes();
   void defineHLSLVectorAlias();
+  void defineHLSLMatrixAlias();
   void defineHLSLTypesWithForwardDeclarations();
   void onCompletion(CXXRecordDecl *Record, CompletionFunction Fn);
 };

diff  --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index cd59678d67f2f..f3448af5f8f50 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -846,16 +846,45 @@ void TypePrinter::printExtVectorAfter(const ExtVectorType *T, raw_ostream &OS) {
   }
 }
 
-void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
-                                            raw_ostream &OS) {
-  printBefore(T->getElementType(), OS);
-  OS << " __attribute__((matrix_type(";
+static void printDims(const ConstantMatrixType *T, raw_ostream &OS) {
   OS << T->getNumRows() << ", " << T->getNumColumns();
+}
+
+static void printHLSLMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T,
+                                  raw_ostream &OS) {
+  OS << "matrix<";
+  TP.printBefore(T->getElementType(), OS);
+}
+
+static void printHLSLMatrixAfter(const ConstantMatrixType *T, raw_ostream &OS) {
+  OS << ", ";
+  printDims(T, OS);
+  OS << ">";
+}
+
+static void printClangMatrixBefore(TypePrinter &TP, const ConstantMatrixType *T,
+                                   raw_ostream &OS) {
+  TP.printBefore(T->getElementType(), OS);
+  OS << " __attribute__((matrix_type(";
+  printDims(T, OS);
   OS << ")))";
 }
 
+void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
+                                            raw_ostream &OS) {
+  if (Policy.UseHLSLTypes) {
+    printHLSLMatrixBefore(*this, T, OS);
+    return;
+  }
+  printClangMatrixBefore(*this, T, OS);
+}
+
 void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
                                            raw_ostream &OS) {
+  if (Policy.UseHLSLTypes) {
+    printHLSLMatrixAfter(T, OS);
+    return;
+  }
   printAfter(T->getElementType(), OS);
 }
 

diff  --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h b/clang/lib/Headers/hlsl/hlsl_basic_types.h
index eff94e0d7f950..fc1e265067714 100644
--- a/clang/lib/Headers/hlsl/hlsl_basic_types.h
+++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h
@@ -115,6 +115,239 @@ typedef vector<float64_t, 2> float64_t2;
 typedef vector<float64_t, 3> float64_t3;
 typedef vector<float64_t, 4> float64_t4;
 
+#ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<int16_t, 1, 1> int16_t1x1;
+typedef matrix<int16_t, 1, 2> int16_t1x2;
+typedef matrix<int16_t, 1, 3> int16_t1x3;
+typedef matrix<int16_t, 1, 4> int16_t1x4;
+typedef matrix<int16_t, 2, 1> int16_t2x1;
+typedef matrix<int16_t, 2, 2> int16_t2x2;
+typedef matrix<int16_t, 2, 3> int16_t2x3;
+typedef matrix<int16_t, 2, 4> int16_t2x4;
+typedef matrix<int16_t, 3, 1> int16_t3x1;
+typedef matrix<int16_t, 3, 2> int16_t3x2;
+typedef matrix<int16_t, 3, 3> int16_t3x3;
+typedef matrix<int16_t, 3, 4> int16_t3x4;
+typedef matrix<int16_t, 4, 1> int16_t4x1;
+typedef matrix<int16_t, 4, 2> int16_t4x2;
+typedef matrix<int16_t, 4, 3> int16_t4x3;
+typedef matrix<int16_t, 4, 4> int16_t4x4;
+typedef matrix<uint16_t, 1, 1> uint16_t1x1;
+typedef matrix<uint16_t, 1, 2> uint16_t1x2;
+typedef matrix<uint16_t, 1, 3> uint16_t1x3;
+typedef matrix<uint16_t, 1, 4> uint16_t1x4;
+typedef matrix<uint16_t, 2, 1> uint16_t2x1;
+typedef matrix<uint16_t, 2, 2> uint16_t2x2;
+typedef matrix<uint16_t, 2, 3> uint16_t2x3;
+typedef matrix<uint16_t, 2, 4> uint16_t2x4;
+typedef matrix<uint16_t, 3, 1> uint16_t3x1;
+typedef matrix<uint16_t, 3, 2> uint16_t3x2;
+typedef matrix<uint16_t, 3, 3> uint16_t3x3;
+typedef matrix<uint16_t, 3, 4> uint16_t3x4;
+typedef matrix<uint16_t, 4, 1> uint16_t4x1;
+typedef matrix<uint16_t, 4, 2> uint16_t4x2;
+typedef matrix<uint16_t, 4, 3> uint16_t4x3;
+typedef matrix<uint16_t, 4, 4> uint16_t4x4;
+#endif
+
+typedef matrix<int, 1, 1> int1x1;
+typedef matrix<int, 1, 2> int1x2;
+typedef matrix<int, 1, 3> int1x3;
+typedef matrix<int, 1, 4> int1x4;
+typedef matrix<int, 2, 1> int2x1;
+typedef matrix<int, 2, 2> int2x2;
+typedef matrix<int, 2, 3> int2x3;
+typedef matrix<int, 2, 4> int2x4;
+typedef matrix<int, 3, 1> int3x1;
+typedef matrix<int, 3, 2> int3x2;
+typedef matrix<int, 3, 3> int3x3;
+typedef matrix<int, 3, 4> int3x4;
+typedef matrix<int, 4, 1> int4x1;
+typedef matrix<int, 4, 2> int4x2;
+typedef matrix<int, 4, 3> int4x3;
+typedef matrix<int, 4, 4> int4x4;
+typedef matrix<uint, 1, 1> uint1x1;
+typedef matrix<uint, 1, 2> uint1x2;
+typedef matrix<uint, 1, 3> uint1x3;
+typedef matrix<uint, 1, 4> uint1x4;
+typedef matrix<uint, 2, 1> uint2x1;
+typedef matrix<uint, 2, 2> uint2x2;
+typedef matrix<uint, 2, 3> uint2x3;
+typedef matrix<uint, 2, 4> uint2x4;
+typedef matrix<uint, 3, 1> uint3x1;
+typedef matrix<uint, 3, 2> uint3x2;
+typedef matrix<uint, 3, 3> uint3x3;
+typedef matrix<uint, 3, 4> uint3x4;
+typedef matrix<uint, 4, 1> uint4x1;
+typedef matrix<uint, 4, 2> uint4x2;
+typedef matrix<uint, 4, 3> uint4x3;
+typedef matrix<uint, 4, 4> uint4x4;
+typedef matrix<int32_t, 1, 1> int32_t1x1;
+typedef matrix<int32_t, 1, 2> int32_t1x2;
+typedef matrix<int32_t, 1, 3> int32_t1x3;
+typedef matrix<int32_t, 1, 4> int32_t1x4;
+typedef matrix<int32_t, 2, 1> int32_t2x1;
+typedef matrix<int32_t, 2, 2> int32_t2x2;
+typedef matrix<int32_t, 2, 3> int32_t2x3;
+typedef matrix<int32_t, 2, 4> int32_t2x4;
+typedef matrix<int32_t, 3, 1> int32_t3x1;
+typedef matrix<int32_t, 3, 2> int32_t3x2;
+typedef matrix<int32_t, 3, 3> int32_t3x3;
+typedef matrix<int32_t, 3, 4> int32_t3x4;
+typedef matrix<int32_t, 4, 1> int32_t4x1;
+typedef matrix<int32_t, 4, 2> int32_t4x2;
+typedef matrix<int32_t, 4, 3> int32_t4x3;
+typedef matrix<int32_t, 4, 4> int32_t4x4;
+typedef matrix<uint32_t, 1, 1> uint32_t1x1;
+typedef matrix<uint32_t, 1, 2> uint32_t1x2;
+typedef matrix<uint32_t, 1, 3> uint32_t1x3;
+typedef matrix<uint32_t, 1, 4> uint32_t1x4;
+typedef matrix<uint32_t, 2, 1> uint32_t2x1;
+typedef matrix<uint32_t, 2, 2> uint32_t2x2;
+typedef matrix<uint32_t, 2, 3> uint32_t2x3;
+typedef matrix<uint32_t, 2, 4> uint32_t2x4;
+typedef matrix<uint32_t, 3, 1> uint32_t3x1;
+typedef matrix<uint32_t, 3, 2> uint32_t3x2;
+typedef matrix<uint32_t, 3, 3> uint32_t3x3;
+typedef matrix<uint32_t, 3, 4> uint32_t3x4;
+typedef matrix<uint32_t, 4, 1> uint32_t4x1;
+typedef matrix<uint32_t, 4, 2> uint32_t4x2;
+typedef matrix<uint32_t, 4, 3> uint32_t4x3;
+typedef matrix<uint32_t, 4, 4> uint32_t4x4;
+typedef matrix<int64_t, 1, 1> int64_t1x1;
+typedef matrix<int64_t, 1, 2> int64_t1x2;
+typedef matrix<int64_t, 1, 3> int64_t1x3;
+typedef matrix<int64_t, 1, 4> int64_t1x4;
+typedef matrix<int64_t, 2, 1> int64_t2x1;
+typedef matrix<int64_t, 2, 2> int64_t2x2;
+typedef matrix<int64_t, 2, 3> int64_t2x3;
+typedef matrix<int64_t, 2, 4> int64_t2x4;
+typedef matrix<int64_t, 3, 1> int64_t3x1;
+typedef matrix<int64_t, 3, 2> int64_t3x2;
+typedef matrix<int64_t, 3, 3> int64_t3x3;
+typedef matrix<int64_t, 3, 4> int64_t3x4;
+typedef matrix<int64_t, 4, 1> int64_t4x1;
+typedef matrix<int64_t, 4, 2> int64_t4x2;
+typedef matrix<int64_t, 4, 3> int64_t4x3;
+typedef matrix<int64_t, 4, 4> int64_t4x4;
+typedef matrix<uint64_t, 1, 1> uint64_t1x1;
+typedef matrix<uint64_t, 1, 2> uint64_t1x2;
+typedef matrix<uint64_t, 1, 3> uint64_t1x3;
+typedef matrix<uint64_t, 1, 4> uint64_t1x4;
+typedef matrix<uint64_t, 2, 1> uint64_t2x1;
+typedef matrix<uint64_t, 2, 2> uint64_t2x2;
+typedef matrix<uint64_t, 2, 3> uint64_t2x3;
+typedef matrix<uint64_t, 2, 4> uint64_t2x4;
+typedef matrix<uint64_t, 3, 1> uint64_t3x1;
+typedef matrix<uint64_t, 3, 2> uint64_t3x2;
+typedef matrix<uint64_t, 3, 3> uint64_t3x3;
+typedef matrix<uint64_t, 3, 4> uint64_t3x4;
+typedef matrix<uint64_t, 4, 1> uint64_t4x1;
+typedef matrix<uint64_t, 4, 2> uint64_t4x2;
+typedef matrix<uint64_t, 4, 3> uint64_t4x3;
+typedef matrix<uint64_t, 4, 4> uint64_t4x4;
+
+typedef matrix<half, 1, 1> half1x1;
+typedef matrix<half, 1, 2> half1x2;
+typedef matrix<half, 1, 3> half1x3;
+typedef matrix<half, 1, 4> half1x4;
+typedef matrix<half, 2, 1> half2x1;
+typedef matrix<half, 2, 2> half2x2;
+typedef matrix<half, 2, 3> half2x3;
+typedef matrix<half, 2, 4> half2x4;
+typedef matrix<half, 3, 1> half3x1;
+typedef matrix<half, 3, 2> half3x2;
+typedef matrix<half, 3, 3> half3x3;
+typedef matrix<half, 3, 4> half3x4;
+typedef matrix<half, 4, 1> half4x1;
+typedef matrix<half, 4, 2> half4x2;
+typedef matrix<half, 4, 3> half4x3;
+typedef matrix<half, 4, 4> half4x4;
+typedef matrix<float, 1, 1> float1x1;
+typedef matrix<float, 1, 2> float1x2;
+typedef matrix<float, 1, 3> float1x3;
+typedef matrix<float, 1, 4> float1x4;
+typedef matrix<float, 2, 1> float2x1;
+typedef matrix<float, 2, 2> float2x2;
+typedef matrix<float, 2, 3> float2x3;
+typedef matrix<float, 2, 4> float2x4;
+typedef matrix<float, 3, 1> float3x1;
+typedef matrix<float, 3, 2> float3x2;
+typedef matrix<float, 3, 3> float3x3;
+typedef matrix<float, 3, 4> float3x4;
+typedef matrix<float, 4, 1> float4x1;
+typedef matrix<float, 4, 2> float4x2;
+typedef matrix<float, 4, 3> float4x3;
+typedef matrix<float, 4, 4> float4x4;
+typedef matrix<double, 1, 1> double1x1;
+typedef matrix<double, 1, 2> double1x2;
+typedef matrix<double, 1, 3> double1x3;
+typedef matrix<double, 1, 4> double1x4;
+typedef matrix<double, 2, 1> double2x1;
+typedef matrix<double, 2, 2> double2x2;
+typedef matrix<double, 2, 3> double2x3;
+typedef matrix<double, 2, 4> double2x4;
+typedef matrix<double, 3, 1> double3x1;
+typedef matrix<double, 3, 2> double3x2;
+typedef matrix<double, 3, 3> double3x3;
+typedef matrix<double, 3, 4> double3x4;
+typedef matrix<double, 4, 1> double4x1;
+typedef matrix<double, 4, 2> double4x2;
+typedef matrix<double, 4, 3> double4x3;
+typedef matrix<double, 4, 4> double4x4;
+
+#ifdef __HLSL_ENABLE_16_BIT
+typedef matrix<float16_t, 1, 1> float16_t1x1;
+typedef matrix<float16_t, 1, 2> float16_t1x2;
+typedef matrix<float16_t, 1, 3> float16_t1x3;
+typedef matrix<float16_t, 1, 4> float16_t1x4;
+typedef matrix<float16_t, 2, 1> float16_t2x1;
+typedef matrix<float16_t, 2, 2> float16_t2x2;
+typedef matrix<float16_t, 2, 3> float16_t2x3;
+typedef matrix<float16_t, 2, 4> float16_t2x4;
+typedef matrix<float16_t, 3, 1> float16_t3x1;
+typedef matrix<float16_t, 3, 2> float16_t3x2;
+typedef matrix<float16_t, 3, 3> float16_t3x3;
+typedef matrix<float16_t, 3, 4> float16_t3x4;
+typedef matrix<float16_t, 4, 1> float16_t4x1;
+typedef matrix<float16_t, 4, 2> float16_t4x2;
+typedef matrix<float16_t, 4, 3> float16_t4x3;
+typedef matrix<float16_t, 4, 4> float16_t4x4;
+#endif
+
+typedef matrix<float32_t, 1, 1> float32_t1x1;
+typedef matrix<float32_t, 1, 2> float32_t1x2;
+typedef matrix<float32_t, 1, 3> float32_t1x3;
+typedef matrix<float32_t, 1, 4> float32_t1x4;
+typedef matrix<float32_t, 2, 1> float32_t2x1;
+typedef matrix<float32_t, 2, 2> float32_t2x2;
+typedef matrix<float32_t, 2, 3> float32_t2x3;
+typedef matrix<float32_t, 2, 4> float32_t2x4;
+typedef matrix<float32_t, 3, 1> float32_t3x1;
+typedef matrix<float32_t, 3, 2> float32_t3x2;
+typedef matrix<float32_t, 3, 3> float32_t3x3;
+typedef matrix<float32_t, 3, 4> float32_t3x4;
+typedef matrix<float32_t, 4, 1> float32_t4x1;
+typedef matrix<float32_t, 4, 2> float32_t4x2;
+typedef matrix<float32_t, 4, 3> float32_t4x3;
+typedef matrix<float32_t, 4, 4> float32_t4x4;
+typedef matrix<float64_t, 1, 1> float64_t1x1;
+typedef matrix<float64_t, 1, 2> float64_t1x2;
+typedef matrix<float64_t, 1, 3> float64_t1x3;
+typedef matrix<float64_t, 1, 4> float64_t1x4;
+typedef matrix<float64_t, 2, 1> float64_t2x1;
+typedef matrix<float64_t, 2, 2> float64_t2x2;
+typedef matrix<float64_t, 2, 3> float64_t2x3;
+typedef matrix<float64_t, 2, 4> float64_t2x4;
+typedef matrix<float64_t, 3, 1> float64_t3x1;
+typedef matrix<float64_t, 3, 2> float64_t3x2;
+typedef matrix<float64_t, 3, 3> float64_t3x3;
+typedef matrix<float64_t, 3, 4> float64_t3x4;
+typedef matrix<float64_t, 4, 1> float64_t4x1;
+typedef matrix<float64_t, 4, 2> float64_t4x2;
+typedef matrix<float64_t, 4, 3> float64_t4x3;
+typedef matrix<float64_t, 4, 4> float64_t4x4;
+
 } // namespace hlsl
 
 #endif //_HLSL_HLSL_BASIC_TYPES_H_

diff  --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index 781f0445d0b61..464922b6257b6 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -121,8 +121,110 @@ void HLSLExternalSemaSource::defineHLSLVectorAlias() {
   HLSLNamespace->addDecl(Template);
 }
 
+void HLSLExternalSemaSource::defineHLSLMatrixAlias() {
+  ASTContext &AST = SemaPtr->getASTContext();
+  llvm::SmallVector<NamedDecl *> TemplateParams;
+
+  auto *TypeParam = TemplateTypeParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 0,
+      &AST.Idents.get("element", tok::TokenKind::identifier), false, false);
+  TypeParam->setDefaultArgument(
+      AST, SemaPtr->getTrivialTemplateArgumentLoc(
+               TemplateArgument(AST.FloatTy), QualType(), SourceLocation()));
+
+  TemplateParams.emplace_back(TypeParam);
+
+  // TODO: these should be 64 bit to be consistent with other clang matrices.
+  auto *RowsParam = NonTypeTemplateParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 1,
+      &AST.Idents.get("rows_count", tok::TokenKind::identifier), AST.IntTy,
+      false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+  llvm::APInt RVal(AST.getIntWidth(AST.IntTy), 4);
+  TemplateArgument RDefault(AST, llvm::APSInt(std::move(RVal)), AST.IntTy,
+                            /*IsDefaulted=*/true);
+  RowsParam->setDefaultArgument(
+      AST, SemaPtr->getTrivialTemplateArgumentLoc(RDefault, AST.IntTy,
+                                                  SourceLocation(), RowsParam));
+  TemplateParams.emplace_back(RowsParam);
+
+  auto *ColsParam = NonTypeTemplateParmDecl::Create(
+      AST, HLSLNamespace, SourceLocation(), SourceLocation(), 0, 2,
+      &AST.Idents.get("cols_count", tok::TokenKind::identifier), AST.IntTy,
+      false, AST.getTrivialTypeSourceInfo(AST.IntTy));
+  llvm::APInt CVal(AST.getIntWidth(AST.IntTy), 4);
+  TemplateArgument CDefault(AST, llvm::APSInt(std::move(CVal)), AST.IntTy,
+                            /*IsDefaulted=*/true);
+  ColsParam->setDefaultArgument(
+      AST, SemaPtr->getTrivialTemplateArgumentLoc(CDefault, AST.IntTy,
+                                                  SourceLocation(), ColsParam));
+  TemplateParams.emplace_back(ColsParam);
+
+  const unsigned MaxMatDim = 4;
+  auto *MaxRow = IntegerLiteral::Create(
+      AST, llvm::APInt(AST.getIntWidth(AST.IntTy), MaxMatDim), AST.IntTy,
+      SourceLocation());
+  auto *MaxCol = IntegerLiteral::Create(
+      AST, llvm::APInt(AST.getIntWidth(AST.IntTy), MaxMatDim), AST.IntTy,
+      SourceLocation());
+
+  auto *RowsRef = DeclRefExpr::Create(
+      AST, NestedNameSpecifierLoc(), SourceLocation(), RowsParam,
+      /*RefersToEnclosingVariableOrCapture*/ false,
+      DeclarationNameInfo(RowsParam->getDeclName(), SourceLocation()),
+      AST.IntTy, VK_LValue);
+  auto *ColsRef = DeclRefExpr::Create(
+      AST, NestedNameSpecifierLoc(), SourceLocation(), ColsParam,
+      /*RefersToEnclosingVariableOrCapture*/ false,
+      DeclarationNameInfo(ColsParam->getDeclName(), SourceLocation()),
+      AST.IntTy, VK_LValue);
+
+  auto *RowsLE = BinaryOperator::Create(AST, RowsRef, MaxRow, BO_LE, AST.BoolTy,
+                                        VK_PRValue, OK_Ordinary,
+                                        SourceLocation(), FPOptionsOverride());
+  auto *ColsLE = BinaryOperator::Create(AST, ColsRef, MaxCol, BO_LE, AST.BoolTy,
+                                        VK_PRValue, OK_Ordinary,
+                                        SourceLocation(), FPOptionsOverride());
+
+  auto *RequiresExpr = BinaryOperator::Create(
+      AST, RowsLE, ColsLE, BO_LAnd, AST.BoolTy, VK_PRValue, OK_Ordinary,
+      SourceLocation(), FPOptionsOverride());
+
+  auto *ParamList = TemplateParameterList::Create(
+      AST, SourceLocation(), SourceLocation(), TemplateParams, SourceLocation(),
+      RequiresExpr);
+
+  IdentifierInfo &II = AST.Idents.get("matrix", tok::TokenKind::identifier);
+
+  QualType AliasType = AST.getDependentSizedMatrixType(
+      AST.getTemplateTypeParmType(0, 0, false, TypeParam),
+      DeclRefExpr::Create(
+          AST, NestedNameSpecifierLoc(), SourceLocation(), RowsParam, false,
+          DeclarationNameInfo(RowsParam->getDeclName(), SourceLocation()),
+          AST.IntTy, VK_LValue),
+      DeclRefExpr::Create(
+          AST, NestedNameSpecifierLoc(), SourceLocation(), ColsParam, false,
+          DeclarationNameInfo(ColsParam->getDeclName(), SourceLocation()),
+          AST.IntTy, VK_LValue),
+      SourceLocation());
+
+  auto *Record = TypeAliasDecl::Create(AST, HLSLNamespace, SourceLocation(),
+                                       SourceLocation(), &II,
+                                       AST.getTrivialTypeSourceInfo(AliasType));
+  Record->setImplicit(true);
+
+  auto *Template =
+      TypeAliasTemplateDecl::Create(AST, HLSLNamespace, SourceLocation(),
+                                    Record->getIdentifier(), ParamList, Record);
+
+  Record->setDescribedAliasTemplate(Template);
+  Template->setImplicit(true);
+  Template->setLexicalDeclContext(Record->getDeclContext());
+  HLSLNamespace->addDecl(Template);
+}
+
 void HLSLExternalSemaSource::defineTrivialHLSLTypes() {
   defineHLSLVectorAlias();
+  defineHLSLMatrixAlias();
 }
 
 /// Set up common members and attributes for buffer types

diff  --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 7a61474e2a25c..b59d001d8dd14 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3291,7 +3291,6 @@ static void BuildFlattenedTypeList(QualType BaseTy,
   while (!WorkList.empty()) {
     QualType T = WorkList.pop_back_val();
     T = T.getCanonicalType().getUnqualifiedType();
-    assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
     if (const auto *AT = dyn_cast<ConstantArrayType>(T)) {
       llvm::SmallVector<QualType, 16> ElementFields;
       // Generally I've avoided recursion in this algorithm, but arrays of

diff  --git a/clang/test/AST/HLSL/matrix-alias.hlsl b/clang/test/AST/HLSL/matrix-alias.hlsl
new file mode 100644
index 0000000000000..2758b6f0d202f
--- /dev/null
+++ b/clang/test/AST/HLSL/matrix-alias.hlsl
@@ -0,0 +1,49 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -ast-dump -o - %s | FileCheck %s
+
+// Test that matrix aliases are set up properly for HLSL
+
+// CHECK: NamespaceDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit hlsl
+// CHECK-NEXT: TypeAliasTemplateDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector
+// CHECK-NEXT: TemplateTypeParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> class depth 0 index 0 element
+// CHECK-NEXT: TemplateArgument type 'float'
+// CHECK-NEXT: BuiltinType 0x{{[0-9a-fA-F]+}} 'float'
+// CHECK-NEXT: NonTypeTemplateParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> 'int' depth 0 index 1 element_count
+// CHECK-NEXT: TemplateArgument expr
+// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' 4
+// CHECK-NEXT: TypeAliasDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector 'vector<element, element_count>'
+// CHECK-NEXT: DependentSizedExtVectorType 0x{{[0-9a-fA-F]+}} 'vector<element, element_count>' dependent <invalid sloc>
+// CHECK-NEXT: TemplateTypeParmType 0x{{[0-9a-fA-F]+}} 'element' dependent depth 0 index 0
+// CHECK-NEXT: TemplateTypeParm 0x{{[0-9a-fA-F]+}} 'element'
+// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' lvalue
+// CHECK-SAME: NonTypeTemplateParm 0x{{[0-9a-fA-F]+}} 'element_count' 'int'
+
+// Make sure we got a using directive at the end.
+// CHECK: UsingDirectiveDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> Namespace 0x{{[0-9a-fA-F]+}} 'hlsl'
+
+[numthreads(1,1,1)]
+int entry() {
+  // Verify that the alias is generated inside the hlsl namespace.
+  hlsl::matrix<float, 2, 2> Mat2x2f;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:26:3, col:36>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:29> col:29 Mat2x2f 'hlsl::matrix<float, 2, 2>'
+
+  // Verify that you don't need to specify the namespace.
+  matrix<int, 2, 2> Mat2x2i;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:32:3, col:28>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:21> col:21 Mat2x2i 'matrix<int, 2, 2>'
+
+  // Build a bigger matrix.
+  matrix<double, 4, 4> Mat4x4d;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:38:3, col:31>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:24> col:24 Mat4x4d 'matrix<double, 4, 4>'
+
+  // Verify that the implicit arguments generate the correct type.
+  matrix<> ImpMat4x4;
+
+  // CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:44:3, col:21>
+  // CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:12> col:12 ImpMat4x4 'matrix<>':'matrix<float, 4, 4>'
+  return 1;
+}

diff  --git a/clang/test/CodeGenHLSL/builtins/transpose-builtin.hlsl b/clang/test/CodeGenHLSL/builtins/transpose-builtin.hlsl
new file mode 100644
index 0000000000000..86aa7cd6985dd
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/transpose-builtin.hlsl
@@ -0,0 +1,30 @@
+// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+// NOTE: This test is only to confirm we can do codegen with the matrix alias.
+
+// CHECK-LABEL: define {{.*}}transpose_half_2x2
+void transpose_half_2x2(half2x2 a) {
+  // CHECK:        [[A:%.*]] = load <4 x half>, ptr {{.*}}, align 2
+  // CHECK-NEXT:   [[TRANS:%.*]] = call {{.*}}<4 x half> @llvm.matrix.transpose.v4f16(<4 x half> [[A]], i32 2, i32 2)
+  // CHECK-NEXT:   store <4 x half> [[TRANS]], ptr %a_t, align 2
+
+  half2x2 a_t = __builtin_matrix_transpose(a);
+}
+
+// CHECK-LABEL: define {{.*}}transpose_float_3x2
+void transpose_float_3x2(float3x2 a) {
+  // CHECK:        [[A:%.*]] = load <6 x float>, ptr {{.*}}, align 4
+  // CHECK-NEXT:   [[TRANS:%.*]] = call {{.*}}<6 x float> @llvm.matrix.transpose.v6f32(<6 x float> [[A]], i32 3, i32 2)
+  // CHECK-NEXT:   store <6 x float> [[TRANS]], ptr %a_t, align 4
+
+  float2x3 a_t = __builtin_matrix_transpose(a);
+}
+
+// CHECK-LABEL: define {{.*}}transpose_int_4x3
+void transpose_int_4x3(int4x3 a) {
+  // CHECK:         [[A:%.*]] = load <12 x i32>, ptr {{.*}}, align 4
+  // CHECK-NEXT:    [[TRANS:%.*]] = call <12 x i32> @llvm.matrix.transpose.v12i32(<12 x i32> [[A]], i32 4, i32 3)
+  // CHECK-NEXT:    store <12 x i32> [[TRANS]], ptr %a_t, align 4
+
+  int3x4 a_t = __builtin_matrix_transpose(a);
+}

diff  --git a/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl
new file mode 100644
index 0000000000000..6f6b01bac829e
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/matrix-basic_types-errors.hlsl
@@ -0,0 +1,12 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+uint64_t5x5 mat;
+// expected-error at -1  {{unknown type name 'uint64_t5x5'}}
+
+// Note: this one only fails because -fnative-half-type is not set
+uint16_t4x4 mat2;
+// expected-error at -1  {{unknown type name 'uint16_t4x4'}}
+
+matrix<int, 5, 5> mat3;
+// expected-error at -1 {{constraints not satisfied for alias template 'matrix' [with element = int, rows_count = 5, cols_count = 5]}}
+// expected-note@* {{because '5 <= 4' (5 <= 4) evaluated to false}}

diff  --git a/clang/test/SemaHLSL/BuiltIns/matrix-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/matrix-errors.hlsl
new file mode 100644
index 0000000000000..03751878bbb98
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/matrix-errors.hlsl
@@ -0,0 +1,29 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -fsyntax-only -verify %s
+
+// Some bad declarations
+hlsl::matrix ShouldWorkSomeday; // expected-error{{use of alias template 'hlsl::matrix' requires template arguments}}
+// expected-note@*:* {{template declaration from hidden source: template <class element = float, int rows_count = 4, int cols_count = 4> requires rows_count <= 4 && cols_count <= 4 using matrix = element __attribute__((matrix_type(rows_count, cols_count)))}}
+
+hlsl::matrix<1,1,1> BadMat; // expected-error{{template argument for template type parameter must be a type}}
+// expected-note@*:* {{template parameter from hidden source: class element = float}}
+
+hlsl::matrix<int, float,4> AnotherBadMat; // expected-error{{template argument for non-type template parameter must be an expression}}
+// expected-note@*:* {{template parameter from hidden source: int rows_count = 4}}
+
+hlsl::matrix<int, 2, 3, 2> YABV; // expected-error{{too many template arguments for alias template 'matrix'}}
+// expected-note@*:* {{template declaration from hidden source: template <class element = float, int rows_count = 4, int cols_count = 4> requires rows_count <= 4 && cols_count <= 4 using matrix = element __attribute__((matrix_type(rows_count, cols_count)))}}
+
+// This code is rejected by clang because clang puts the HLSL built-in types
+// into the HLSL namespace.
+namespace hlsl {
+  struct matrix {}; // expected-error {{redefinition of 'matrix'}}
+}
+
+// This code is rejected by dxc because dxc puts the HLSL built-in types
+// into the global space, but clang will allow it even though it will shadow the
+// matrix template.
+struct matrix {}; // expected-note {{candidate found by name lookup is 'matrix'}}
+
+matrix<int,2,2> matInt2x2; // expected-error {{reference to 'matrix' is ambiguous}}
+
+// expected-note@*:* {{candidate found by name lookup is 'hlsl::matrix'}}


        


More information about the cfe-commits mailing list