[clang] Support HLSL matrix initializers (PR #160960)
via cfe-commits
cfe-commits at lists.llvm.org
Fri Sep 26 15:30:41 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-hlsl
Author: Farzon Lotfi (farzonl)
<details>
<summary>Changes</summary>
Fixes #159434
In HLSL, matrices are `matrix_type` in all respects, except that they also support a constructor-style syntax for initialization. This change adds a translation of matrix constructor arguments into initializer lists.
This supports the following HLSL syntax:
(1) HLSL matrices support constructor syntax
(2) HLSL matrix arguments in a constructor are expanded into their constituent components
---
Patch is 47.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160960.diff
6 Files Affected:
- (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+4-4)
- (modified) clang/include/clang/Sema/Initialization.h (+8-4)
- (modified) clang/lib/Sema/CheckExprLifetime.cpp (+1)
- (modified) clang/lib/Sema/SemaInit.cpp (+186-13)
- (added) clang/test/AST/HLSL/matrix-constructors.hlsl (+338)
- (added) clang/test/SemaHLSL/BuiltIns/matrix-constructors-errors.hlsl (+24)
``````````diff
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index bd0e53d3086b0..19e4c548d0208 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -6543,9 +6543,9 @@ def warn_extern_init : Warning<"'extern' variable has an initializer">,
def err_variable_object_no_init : Error<
"variable-sized object may not be initialized">;
def err_excess_initializers : Error<
- "excess elements in %select{array|vector|scalar|union|struct}0 initializer">;
+ "excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">;
def ext_excess_initializers : ExtWarn<
- "excess elements in %select{array|vector|scalar|union|struct}0 initializer">,
+ "excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">,
InGroup<ExcessInitializers>;
def err_excess_initializers_for_sizeless_type : Error<
"excess elements in initializer for indivisible sizeless type %0">;
@@ -11086,8 +11086,8 @@ def err_first_argument_to_cwsc_pdtor_call : Error<
def err_second_argument_to_cwsc_not_pointer : Error<
"second argument to __builtin_call_with_static_chain must be of pointer type">;
-def err_vector_incorrect_num_elements : Error<
- "%select{too many|too few}0 elements in vector %select{initialization|operand}3 (expected %1 elements, have %2)">;
+def err_tensor_incorrect_num_elements : Error<
+ "%select{too many|too few}0 elements in %select{vector|matrix}1 %select{initialization|operand}4 (expected %2 elements, have %3)">;
def err_altivec_empty_initializer : Error<"expected initializer">;
def err_vector_incorrect_bit_count : Error<
diff --git a/clang/include/clang/Sema/Initialization.h b/clang/include/clang/Sema/Initialization.h
index d7675ea153afb..865b6428f3081 100644
--- a/clang/include/clang/Sema/Initialization.h
+++ b/clang/include/clang/Sema/Initialization.h
@@ -91,6 +91,10 @@ class alignas(8) InitializedEntity {
/// or vector.
EK_VectorElement,
+ /// The entity being initialized is an element of a matrix
+ /// (for example, in an initializer list for a matrix type).
+ EK_MatrixElement,
+
/// The entity being initialized is a field of block descriptor for
/// the copied-in c++ object.
EK_BlockElement,
@@ -205,8 +209,8 @@ class alignas(8) InitializedEntity {
/// virtual base.
llvm::PointerIntPair<const CXXBaseSpecifier *, 1> Base;
- /// When Kind == EK_ArrayElement, EK_VectorElement, or
- /// EK_ComplexElement, the index of the array or vector element being
+ /// When Kind == EK_ArrayElement, EK_VectorElement, EK_MatrixElement,
+ /// or EK_ComplexElement, the index of the array, vector, matrix, or complex element being
/// initialized.
unsigned Index;
@@ -536,7 +540,7 @@ class alignas(8) InitializedEntity {
/// element's index.
unsigned getElementIndex() const {
assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
- getKind() == EK_ComplexElement);
+ getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
return Index;
}
@@ -544,7 +548,7 @@ class alignas(8) InitializedEntity {
/// element, sets the element index.
void setElementIndex(unsigned Index) {
assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement ||
- getKind() == EK_ComplexElement);
+ getKind() == EK_MatrixElement || getKind() == EK_ComplexElement);
this->Index = Index;
}
diff --git a/clang/lib/Sema/CheckExprLifetime.cpp b/clang/lib/Sema/CheckExprLifetime.cpp
index e02e00231e58e..d647fbf007838 100644
--- a/clang/lib/Sema/CheckExprLifetime.cpp
+++ b/clang/lib/Sema/CheckExprLifetime.cpp
@@ -154,6 +154,7 @@ getEntityLifetime(const InitializedEntity *Entity,
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
case InitializedEntity::EK_LambdaCapture:
case InitializedEntity::EK_VectorElement:
+ case clang::InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
return {nullptr, LK_FullExpression};
diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index c97129336736b..0f06d715600b3 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -17,6 +17,7 @@
#include "clang/AST/ExprCXX.h"
#include "clang/AST/ExprObjC.h"
#include "clang/AST/IgnoreExpr.h"
+#include "clang/AST/TypeBase.h"
#include "clang/AST/TypeLoc.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Basic/Specifiers.h"
@@ -403,6 +404,9 @@ class InitListChecker {
unsigned &Index,
InitListExpr *StructuredList,
unsigned &StructuredIndex);
+ void CheckMatrixType(const InitializedEntity &Entity, InitListExpr *IList,
+ QualType DeclType, unsigned &Index,
+ InitListExpr *StructuredList, unsigned &StructuredIndex);
void CheckVectorType(const InitializedEntity &Entity,
InitListExpr *IList, QualType DeclType, unsigned &Index,
InitListExpr *StructuredList,
@@ -1003,7 +1007,8 @@ InitListChecker::FillInEmptyInitializations(const InitializedEntity &Entity,
return;
if (ElementEntity.getKind() == InitializedEntity::EK_ArrayElement ||
- ElementEntity.getKind() == InitializedEntity::EK_VectorElement)
+ ElementEntity.getKind() == InitializedEntity::EK_VectorElement ||
+ ElementEntity.getKind() == InitializedEntity::EK_MatrixElement)
ElementEntity.setElementIndex(Init);
if (Init >= NumInits && (ILE->hasArrayFiller() || SkipEmptyInitChecks))
@@ -1273,6 +1278,7 @@ static void warnBracedScalarInit(Sema &S, const InitializedEntity &Entity,
switch (Entity.getKind()) {
case InitializedEntity::EK_VectorElement:
+ case InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
case InitializedEntity::EK_ArrayElement:
case InitializedEntity::EK_Parameter:
@@ -1372,11 +1378,12 @@ void InitListChecker::CheckExplicitInitList(const InitializedEntity &Entity,
SemaRef.Diag(IList->getInit(Index)->getBeginLoc(), DK)
<< T << IList->getInit(Index)->getSourceRange();
} else {
- int initKind = T->isArrayType() ? 0 :
- T->isVectorType() ? 1 :
- T->isScalarType() ? 2 :
- T->isUnionType() ? 3 :
- 4;
+ int initKind = T->isArrayType() ? 0
+ : T->isVectorType() ? 1
+ : T->isMatrixType() ? 2
+ : T->isScalarType() ? 3
+ : T->isUnionType() ? 4
+ : 5;
unsigned DK = ExtraInitsIsError ? diag::err_excess_initializers
: diag::ext_excess_initializers;
@@ -1430,6 +1437,9 @@ void InitListChecker::CheckListElementTypes(const InitializedEntity &Entity,
} else if (DeclType->isVectorType()) {
CheckVectorType(Entity, IList, DeclType, Index,
StructuredList, StructuredIndex);
+ } else if (DeclType->isMatrixType()) {
+ CheckMatrixType(Entity, IList, DeclType, Index, StructuredList,
+ StructuredIndex);
} else if (const RecordDecl *RD = DeclType->getAsRecordDecl()) {
auto Bases =
CXXRecordDecl::base_class_const_range(CXXRecordDecl::base_class_const_iterator(),
@@ -1877,6 +1887,93 @@ void InitListChecker::CheckReferenceType(const InitializedEntity &Entity,
AggrDeductionCandidateParamTypes->push_back(DeclType);
}
+void InitListChecker::CheckMatrixType(const InitializedEntity &Entity,
+ InitListExpr *IList, QualType DeclType,
+ unsigned &Index,
+ InitListExpr *StructuredList,
+ unsigned &StructuredIndex) {
+ if (!SemaRef.getLangOpts().HLSL) // Only HLSL matrix initializer lists are checked here; other languages return without checking.
+ return;
+
+ const ConstantMatrixType *MT = DeclType->castAs<ConstantMatrixType>(); // NOTE(review): the caller dispatches on isMatrixType(); confirm a non-constant (dependent-sized) matrix can never reach this castAs.
+ QualType ElemTy = MT->getElementType();
+ const unsigned Rows = MT->getNumRows();
+ const unsigned Cols = MT->getNumColumns();
+ const unsigned MaxElts = Rows * Cols;
+
+ unsigned NumEltsInit = 0;
+ InitializedEntity ElemEnt =
+ InitializedEntity::InitializeElement(SemaRef.Context, 0, Entity);
+
+ // A matrix initializer accepts scalars, vectors, and matrices; HandleInit checks List->getInit(Idx) and adds the number of matrix elements it supplies to NumEltsInit.
+ auto HandleInit = [&](InitListExpr *List, unsigned &Idx) {
+ Expr *Init = List->getInit(Idx);
+ QualType ITy = Init->getType();
+
+ if (ITy->isVectorType()) { // Vector: checked as an N-element vector of ElemTy, supplying N elements.
+ const VectorType *IVT = ITy->castAs<VectorType>();
+ unsigned N = IVT->getNumElements();
+ QualType VTy =
+ ITy->isExtVectorType()
+ ? SemaRef.Context.getExtVectorType(ElemTy, N)
+ : SemaRef.Context.getVectorType(ElemTy, N, IVT->getVectorKind());
+ ElemEnt.setElementIndex(Idx); // NOTE(review): Idx is the position within List, not the flattened matrix element index.
+ CheckSubElementType(ElemEnt, List, VTy, Idx, StructuredList,
+ StructuredIndex);
+ NumEltsInit += N;
+ return;
+ }
+
+ if (ITy->isMatrixType()) { // Matrix: checked as a same-shaped matrix of ElemTy, supplying Rows*Cols of its own elements.
+ const ConstantMatrixType *IMT = ITy->castAs<ConstantMatrixType>();
+ unsigned N = IMT->getNumRows() * IMT->getNumColumns();
+ QualType MTy = SemaRef.Context.getConstantMatrixType(
+ ElemTy, IMT->getNumRows(), IMT->getNumColumns());
+ ElemEnt.setElementIndex(Idx);
+ CheckSubElementType(ElemEnt, List, MTy, Idx, StructuredList,
+ StructuredIndex);
+ NumEltsInit += N;
+ return;
+ }
+
+ // Anything else: checked as a single scalar of the matrix element type.
+ ElemEnt.setElementIndex(Idx);
+ CheckSubElementType(ElemEnt, List, ElemTy, Idx, StructuredList,
+ StructuredIndex);
+ ++NumEltsInit;
+ };
+
+ // Column-major: each top-level sublist is treated as a column and supplies at most Rows initializers.
+ while (NumEltsInit < MaxElts && Index < IList->getNumInits()) { // Stop once the matrix is full or IList is exhausted.
+ Expr *Init = IList->getInit(Index);
+
+ if (auto *SubList = dyn_cast<InitListExpr>(Init)) {
+ unsigned SubIdx = 0;
+ unsigned Row = 0;
+ while (Row < Rows && SubIdx < SubList->getNumInits() &&
+ NumEltsInit < MaxElts) {
+ HandleInit(SubList, SubIdx);
+ ++Row;
+ }
+ ++Index; // advance past this column sublist (NOTE(review): sublist initializers after the first Rows are skipped without a diagnostic here)
+ continue;
+ }
+
+ // Not a sublist: hand the initializer straight to HandleInit.
+ HandleInit(IList, Index);
+ }
+
+ // HLSL requires exactly Rows*Cols initializers after flattening; a vector or matrix initializer can overshoot MaxElts, so both too many and too few are diagnosed.
+ if (NumEltsInit != MaxElts) {
+ if (!VerifyOnly)
+ SemaRef.Diag(IList->getBeginLoc(),
+ diag::err_tensor_incorrect_num_elements)
+ << (NumEltsInit < MaxElts) << /*matrix*/ 1 << MaxElts << NumEltsInit
+ << /*initialization*/ 0;
+ hadError = true;
+ }
+}
+
void InitListChecker::CheckVectorType(const InitializedEntity &Entity,
InitListExpr *IList, QualType DeclType,
unsigned &Index,
@@ -2026,9 +2123,9 @@ void InitListChecker::CheckVectorType(const InitializedEntity &Entity,
if (numEltsInit != maxElements) {
if (!VerifyOnly)
SemaRef.Diag(IList->getBeginLoc(),
- diag::err_vector_incorrect_num_elements)
- << (numEltsInit < maxElements) << maxElements << numEltsInit
- << /*initialization*/ 0;
+ diag::err_tensor_incorrect_num_elements)
+ << (numEltsInit < maxElements) << /*vector*/ 0 << maxElements
+ << numEltsInit << /*initialization*/ 0;
hadError = true;
}
}
@@ -3639,6 +3736,9 @@ InitializedEntity::InitializedEntity(ASTContext &Context, unsigned Index,
} else if (const VectorType *VT = Parent.getType()->getAs<VectorType>()) {
Kind = EK_VectorElement;
Type = VT->getElementType();
+ } else if (const MatrixType *MT = Parent.getType()->getAs<MatrixType>()) {
+ Kind = EK_MatrixElement;
+ Type = MT->getElementType();
} else {
const ComplexType *CT = Parent.getType()->getAs<ComplexType>();
assert(CT && "Unexpected type");
@@ -3687,6 +3787,7 @@ DeclarationName InitializedEntity::getName() const {
case EK_Delegating:
case EK_ArrayElement:
case EK_VectorElement:
+ case EK_MatrixElement:
case EK_ComplexElement:
case EK_BlockElement:
case EK_LambdaToBlockConversionBlockElement:
@@ -3720,6 +3821,7 @@ ValueDecl *InitializedEntity::getDecl() const {
case EK_Delegating:
case EK_ArrayElement:
case EK_VectorElement:
+ case EK_MatrixElement:
case EK_ComplexElement:
case EK_BlockElement:
case EK_LambdaToBlockConversionBlockElement:
@@ -3753,6 +3855,7 @@ bool InitializedEntity::allowsNRVO() const {
case EK_Delegating:
case EK_ArrayElement:
case EK_VectorElement:
+ case EK_MatrixElement:
case EK_ComplexElement:
case EK_BlockElement:
case EK_LambdaToBlockConversionBlockElement:
@@ -3792,6 +3895,9 @@ unsigned InitializedEntity::dumpImpl(raw_ostream &OS) const {
case EK_Delegating: OS << "Delegating"; break;
case EK_ArrayElement: OS << "ArrayElement " << Index; break;
case EK_VectorElement: OS << "VectorElement " << Index; break;
+ case EK_MatrixElement:
+ OS << "MatrixElement " << Index;
+ break;
case EK_ComplexElement: OS << "ComplexElement " << Index; break;
case EK_BlockElement: OS << "Block"; break;
case EK_LambdaToBlockConversionBlockElement:
@@ -6847,6 +6953,67 @@ void InitializationSequence::InitializeFrom(Sema &S,
return;
}
+ if (S.getLangOpts().HLSL && DestType->isMatrixType() &&
+ (SourceType.isNull() ||
+ !Context.hasSameUnqualifiedType(SourceType, DestType))) {
+
+ llvm::SmallVector<Expr *> InitArgs;
+
+ for (Expr *Arg : Args) {
+ QualType AT = Arg->getType();
+
+ // Expand matrix arguments element-by-element (col-major).
+ if (AT->isMatrixType()) {
+ const auto *MTy = AT->castAs<ConstantMatrixType>();
+ unsigned Rows = MTy->getNumRows();
+ unsigned Cols = MTy->getNumColumns();
+ QualType ElemTy = MTy->getElementType();
+
+ for (unsigned c = 0; c < Cols; ++c) {
+ for (unsigned r = 0; r < Rows; ++r) {
+ // row index literal
+ Expr *RowIdx = IntegerLiteral::Create(
+ Context, llvm::APInt(Context.getIntWidth(Context.IntTy), r),
+ Context.IntTy, Arg->getBeginLoc());
+ // column index literal
+ Expr *ColIdx = IntegerLiteral::Create(
+ Context, llvm::APInt(Context.getIntWidth(Context.IntTy), c),
+ Context.IntTy, Arg->getBeginLoc());
+
+ InitArgs.emplace_back(new (Context) MatrixSubscriptExpr(
+ Arg, RowIdx, ColIdx, ElemTy, Arg->getEndLoc()));
+ }
+ }
+
+ // Also expand ext_vector arguments element-by-element, since vectors
+ // may appear in the argument list too.
+ } else if (AT->isExtVectorType()) {
+ const auto *VTy = AT->castAs<ExtVectorType>();
+ unsigned Elm = VTy->getNumElements();
+ for (unsigned Idx = 0; Idx < Elm; ++Idx) {
+ InitArgs.emplace_back(new (Context) ArraySubscriptExpr(
+ Arg,
+ IntegerLiteral::Create(
+ Context, llvm::APInt(Context.getIntWidth(Context.IntTy), Idx),
+ Context.IntTy, Arg->getBeginLoc()),
+ VTy->getElementType(), Arg->getValueKind(), Arg->getObjectKind(),
+ Arg->getEndLoc()));
+ }
+
+ } else {
+ // Scalar or other: forward as-is
+ InitArgs.emplace_back(Arg);
+ }
+ }
+
+ InitListExpr *ILE = new (Context) InitListExpr(
+ S.getASTContext(), SourceLocation(), InitArgs, SourceLocation());
+
+ Args[0] = ILE;
+ AddListInitializationStep(DestType);
+ return;
+ }
+
// The remaining cases all need a source type.
if (Args.size() > 1) {
SetFailed(FK_TooManyInitsForScalar);
@@ -6999,6 +7166,7 @@ static AssignmentAction getAssignmentAction(const InitializedEntity &Entity,
case InitializedEntity::EK_Binding:
case InitializedEntity::EK_ArrayElement:
case InitializedEntity::EK_VectorElement:
+ case InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
case InitializedEntity::EK_BlockElement:
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7024,6 +7192,7 @@ static bool shouldBindAsTemporary(const InitializedEntity &Entity) {
case InitializedEntity::EK_Base:
case InitializedEntity::EK_Delegating:
case InitializedEntity::EK_VectorElement:
+ case InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
case InitializedEntity::EK_Exception:
case InitializedEntity::EK_BlockElement:
@@ -7054,6 +7223,7 @@ static bool shouldDestroyEntity(const InitializedEntity &Entity) {
case InitializedEntity::EK_Base:
case InitializedEntity::EK_Delegating:
case InitializedEntity::EK_VectorElement:
+ case InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
case InitializedEntity::EK_BlockElement:
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7107,6 +7277,7 @@ static SourceLocation getInitializationLoc(const InitializedEntity &Entity,
case InitializedEntity::EK_Base:
case InitializedEntity::EK_Delegating:
case InitializedEntity::EK_VectorElement:
+ case InitializedEntity::EK_MatrixElement:
case InitializedEntity::EK_ComplexElement:
case InitializedEntity::EK_BlockElement:
case InitializedEntity::EK_LambdaToBlockConversionBlockElement:
@@ -7858,9 +8029,11 @@ ExprResult InitializationSequence::Perform(Sema &S,
// HLSL allows vector initialization to function like list initialization, but
// use the syntax of a C++-like constructor.
- bool IsHLSLVectorInit = S.getLangOpts().HLSL && DestType->isExtVectorType() &&
- isa<InitListExpr>(Args[0]);
- (void)IsHLSLVectorInit;
+ bool IsHLSLVectorOrMatrixInit =
+ S.getLangOpts().HLSL &&
+ (DestType->isExtVectorType() || DestType->isMatrixType()) &&
+ isa<InitListExpr>(Args[0]);
+ (void)IsHLSLVectorOrMatrixInit;
// For initialization steps that start with a single initializer,
// grab the only argument out the Args and place it into the "current"
@@ -7899,7 +8072,7 @@ ExprResult InitializationSequence::Perform(Sema &S,
case SK_StdInitializerList:
case SK_OCLSamplerInit:
case SK_OCLZeroOpaqueType: {
- assert(Args.size() == 1 || IsHLSLVectorInit);
+ assert(Args.size() == 1 || IsHLSLVectorOrMatrixInit);
CurInit = Args[0];
if (!CurInit.get()) return ExprError();
break;
diff --git a/clang/test/AST/HLSL/matrix-constructors.hlsl b/clang/test/AST/HLSL/matrix-constructors.hlsl
new file mode 100644
index 0000000000000..faee0162a314b
--- /dev/null
+++ b/clang/test/AST/HLSL/matrix-constructors.hlsl
@@ -0,0 +1,338 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+
+typedef float float2x1 __attribute__((matrix_type(2,1)));
+typedef float float2x3 __attribute__((matrix_type(2,3)));
+typedef float float2x2 __attribute__((matrix_type(2,2)));
+typedef float float2 __attribute__((ext_vector_type(2)));
+typedef float float4 __attribute__((ext_vector_type(4)));
+
+[numthreads(1,1,1)]
+void ok() {
+
+ // CHECK: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:36> col:12 A 'float2x3':'matrix<float, 2, 3>' cinit
+ // CHECK-NEXT: CXXFunctionalCastExpr 0x{{[0-9a-fA-F]+}} <col:16, col:36> 'float2x3':'matrix<float, 2, 3>' functional cast to float2x3 <NoOp>
+ // CHECK-NEXT: InitListExpr 0x{{[0-9a-fA-F]+}} <col:25, col:35> 'float2x3':'matrix<float, 2, 3>'
+ // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:25> 'float' <IntegralToFloating>
+ // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:25> 'int' 1
+ // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:27> 'float' <IntegralToFloating>
+ // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:27> 'int' 2
+ // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:29> 'float' <IntegralToFloating>
+ // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:29> 'int' 3
+ // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:31> 'float' <IntegralToFloating>
+ // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:31> 'int' 4
+ // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:33> 'float' <IntegralToFloating>
+ // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:33> ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/160960
More information about the cfe-commits
mailing list