[llvm-branch-commits] [clang] [clang][bytecode][HLSL][Matrix] Support `ConstantMatrixType` and more HLSL casts in the new constant interpreter for basic matrix constexpr evaluation in HLSL (PR #183424)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Feb 25 16:55:56 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang
Author: Deric C. (Icohedron)
<details>
<summary>Changes</summary>
Fixes #<!-- -->182963
This PR is an extension of #<!-- -->178762 and is to be merged immediately after it.
This PR adds support for `ConstantMatrixType` and the HLSL casts `CK_HLSLArrayRValue`, `CK_HLSLMatrixTruncation`, `CK_HLSLAggregateSplatCast`, and `CK_HLSLElementwiseCast` to the bytecode constexpr evaluator.
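The implementations of `CK_HLSLAggregateSplatCast` and `CK_HLSLElementwiseCast` are incomplete. They do not yet support struct types. For arrays, the splat cast supports none, and the elementwise cast supports only one-dimensional constant arrays of scalars. Full support is needed before the other existing HLSL constexpr tests can run under the experimental new constant interpreter.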
Completing these casts will be tracked in a separate issue and done in a separate PR.
Assisted-by: claude-opus-4.6
---
Full diff: https://github.com/llvm/llvm-project/pull/183424.diff
5 Files Affected:
- (modified) clang/lib/AST/ByteCode/Compiler.cpp (+207)
- (modified) clang/lib/AST/ByteCode/Compiler.h (+5)
- (modified) clang/lib/AST/ByteCode/Pointer.cpp (+18)
- (modified) clang/lib/AST/ByteCode/Program.cpp (+11)
- (modified) clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl (+2)
``````````diff
diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 15e65a4d96581..87982e67dcb51 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -811,6 +811,180 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
case CK_LValueBitCast:
return this->emitInvalidCast(CastKind::ReinterpretLike, /*Fatal=*/true, CE);
+ case CK_HLSLArrayRValue: {
+ // Non-decaying array rvalue cast - creates an rvalue copy of an lvalue
+ // array, similar to LValueToRValue for composite types.
+ if (!Initializing) {
+ UnsignedOrNone LocalIndex = allocateLocal(CE);
+ if (!LocalIndex)
+ return false;
+ if (!this->emitGetPtrLocal(*LocalIndex, CE))
+ return false;
+ }
+ if (!this->visit(SubExpr))
+ return false;
+ return this->emitMemcpy(CE);
+ }
+
+ case CK_HLSLMatrixTruncation: {
+ assert(SubExpr->getType()->isConstantMatrixType());
+ if (OptPrimType ResultT = classify(CE)) {
+ assert(!DiscardResult);
+ // Result must be either a float or integer. Take the first element.
+ if (!this->visit(SubExpr))
+ return false;
+ return this->emitArrayElemPop(*ResultT, 0, CE);
+ }
+ // Otherwise, this truncates to a a constant matrix type.
+ assert(CE->getType()->isConstantMatrixType());
+
+ if (!Initializing) {
+ UnsignedOrNone LocalIndex = allocateTemporary(CE);
+ if (!LocalIndex)
+ return false;
+ if (!this->emitGetPtrLocal(*LocalIndex, CE))
+ return false;
+ }
+ unsigned ToSize =
+ CE->getType()->getAs<ConstantMatrixType>()->getNumElementsFlattened();
+ if (!this->visit(SubExpr))
+ return false;
+ return this->emitCopyArray(classifyMatrixElementType(SubExpr->getType()), 0,
+ 0, ToSize, CE);
+ }
+
+ case CK_HLSLAggregateSplatCast: {
+ // Aggregate splat cast: convert a scalar value to one of an aggregate type.
+ // TODO: Aggregate splat to struct and array types
+ assert(canClassify(SubExpr->getType()));
+
+ unsigned NumElts;
+ PrimType DestElemT;
+ QualType DestElemType;
+ if (const auto *VT = CE->getType()->getAs<VectorType>()) {
+ NumElts = VT->getNumElements();
+ DestElemType = VT->getElementType();
+ } else if (const auto *MT =
+ CE->getType()->getAs<ConstantMatrixType>()) {
+ NumElts = MT->getNumElementsFlattened();
+ DestElemType = MT->getElementType();
+ } else {
+ return false;
+ }
+ DestElemT = classifyPrim(DestElemType);
+
+ if (!Initializing) {
+ UnsignedOrNone LocalIndex = allocateLocal(CE);
+ if (!LocalIndex)
+ return false;
+ if (!this->emitGetPtrLocal(*LocalIndex, CE))
+ return false;
+ }
+
+ PrimType SrcElemT = classifyPrim(SubExpr->getType());
+ unsigned SrcOffset =
+ allocateLocalPrimitive(SubExpr, DestElemT, /*IsConst=*/true);
+
+ if (!this->visit(SubExpr))
+ return false;
+ if (classifyPrim(SubExpr) == PT_Ptr && !this->emitLoadPop(SrcElemT, CE))
+ return false;
+ if (SrcElemT != DestElemT) {
+ if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE))
+ return false;
+ }
+ if (!this->emitSetLocal(DestElemT, SrcOffset, CE))
+ return false;
+
+ for (unsigned I = 0; I != NumElts; ++I) {
+ if (!this->emitGetLocal(DestElemT, SrcOffset, CE))
+ return false;
+ if (!this->emitInitElem(DestElemT, I, CE))
+ return false;
+ }
+ return true;
+ }
+
+ case CK_HLSLElementwiseCast: {
+ // Elementwise cast: flatten source elements of one aggregate type and store
+ // to a destination aggregate type of the same or fewer number of elements.
+ // TODO: Elementwise cast to structs, nested arrays, and arrays of composite
+ // types
+ QualType SrcType = SubExpr->getType();
+ QualType DestType = CE->getType();
+
+ unsigned SrcNumElts;
+ PrimType SrcElemT;
+ if (const auto *VT = SrcType->getAs<VectorType>()) {
+ SrcNumElts = VT->getNumElements();
+ SrcElemT = classifyPrim(VT->getElementType());
+ } else if (const auto *MT = SrcType->getAs<ConstantMatrixType>()) {
+ SrcNumElts = MT->getNumElementsFlattened();
+ SrcElemT = classifyPrim(MT->getElementType());
+ } else if (const auto *AT = SrcType->getAsArrayTypeUnsafe()) {
+ if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) {
+ SrcNumElts = CAT->getZExtSize();
+ SrcElemT = classifyPrim(CAT->getElementType());
+ } else {
+ return false;
+ }
+ } else {
+ return false;
+ }
+
+ unsigned DestNumElts;
+ PrimType DestElemT;
+ QualType DestElemType;
+ if (const auto *VT = DestType->getAs<VectorType>()) {
+ DestNumElts = VT->getNumElements();
+ DestElemType = VT->getElementType();
+ } else if (const auto *MT = DestType->getAs<ConstantMatrixType>()) {
+ DestNumElts = MT->getNumElementsFlattened();
+ DestElemType = MT->getElementType();
+ } else if (const auto *AT = DestType->getAsArrayTypeUnsafe()) {
+ if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) {
+ DestNumElts = CAT->getZExtSize();
+ DestElemType = CAT->getElementType();
+ } else {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ DestElemT = classifyPrim(DestElemType);
+
+ if (!Initializing) {
+ UnsignedOrNone LocalIndex =
+ classify(DestType) ? allocateLocal(CE) : allocateTemporary(CE);
+ if (!LocalIndex)
+ return false;
+ if (!this->emitGetPtrLocal(*LocalIndex, CE))
+ return false;
+ }
+
+ unsigned SrcOffset =
+ allocateLocalPrimitive(SubExpr, PT_Ptr, /*IsConst=*/true);
+ if (!this->visit(SubExpr))
+ return false;
+ if (!this->emitSetLocal(PT_Ptr, SrcOffset, CE))
+ return false;
+
+ unsigned NumElts = std::min(SrcNumElts, DestNumElts);
+ for (unsigned I = 0; I != NumElts; ++I) {
+ if (!this->emitGetLocal(PT_Ptr, SrcOffset, CE))
+ return false;
+ if (!this->emitArrayElemPop(SrcElemT, I, CE))
+ return false;
+ if (SrcElemT != DestElemT) {
+ if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE))
+ return false;
+ }
+ if (!this->emitInitElem(DestElemT, I, CE))
+ return false;
+ }
+ return true;
+ }
+
default:
return this->emitInvalid(CE);
}
@@ -1813,6 +1987,20 @@ bool Compiler<Emitter>::VisitImplicitValueInitExpr(
return true;
}
+ if (const auto *MT = E->getType()->getAs<ConstantMatrixType>()) {
+ unsigned NumElts = MT->getNumElementsFlattened();
+ QualType ElemQT = MT->getElementType();
+ PrimType ElemT = classifyPrim(ElemQT);
+
+ for (unsigned I = 0; I < NumElts; ++I) {
+ if (!this->visitZeroInitializer(ElemT, ElemQT, E))
+ return false;
+ if (!this->emitInitElem(ElemT, I, E))
+ return false;
+ }
+ return true;
+ }
+
return false;
}
@@ -2129,6 +2317,25 @@ bool Compiler<Emitter>::visitInitList(ArrayRef<const Expr *> Inits,
return true;
}
+ if (const auto *MT = QT->getAs<ConstantMatrixType>()) {
+ unsigned NumElts = MT->getNumElementsFlattened();
+ assert(Inits.size() == NumElts);
+
+ QualType ElemQT = MT->getElementType();
+ PrimType ElemT = classifyPrim(ElemQT);
+
+ // InitListExpr elements are in column-major order.
+ // Store in row-major order to match APValue convention.
+ for (unsigned I = 0; I < NumElts; ++I) {
+ if (!this->visit(Inits[I]))
+ return false;
+ if (!this->emitInitElem(ElemT,
+ MT->mapColumnMajorToRowMajorFlattenedIndex(I), E))
+ return false;
+ }
+ return true;
+ }
+
return false;
}
diff --git a/clang/lib/AST/ByteCode/Compiler.h b/clang/lib/AST/ByteCode/Compiler.h
index 1bd15c3d79563..74ded47e88792 100644
--- a/clang/lib/AST/ByteCode/Compiler.h
+++ b/clang/lib/AST/ByteCode/Compiler.h
@@ -406,6 +406,11 @@ class Compiler : public ConstStmtVisitor<Compiler<Emitter>, bool>,
return *this->classify(T->getAs<VectorType>()->getElementType());
}
+ PrimType classifyMatrixElementType(QualType T) const { // Returns the PrimType of T's matrix element type.
+ assert(T->isMatrixType()); // Callers must pass a matrix type.
+ return *this->classify(T->getAs<MatrixType>()->getElementType()); // Unchecked deref: assumes classify() succeeds for the element type.
+ }
+
bool emitComplexReal(const Expr *SubExpr);
bool emitComplexBoolCast(const Expr *E);
bool emitComplexComparison(const Expr *LHS, const Expr *RHS,
diff --git a/clang/lib/AST/ByteCode/Pointer.cpp b/clang/lib/AST/ByteCode/Pointer.cpp
index e237013f4199c..a569a4221cf2f 100644
--- a/clang/lib/AST/ByteCode/Pointer.cpp
+++ b/clang/lib/AST/ByteCode/Pointer.cpp
@@ -934,6 +934,24 @@ std::optional<APValue> Pointer::toRValue(const Context &Ctx,
return true;
}
+ // Constant Matrix types.
+ if (const auto *MT = Ty->getAs<ConstantMatrixType>()) {
+ assert(Ptr.getFieldDesc()->isPrimitiveArray());
+ QualType ElemTy = MT->getElementType();
+ PrimType ElemT = *Ctx.classify(ElemTy);
+ unsigned NumElts = MT->getNumElementsFlattened();
+
+ SmallVector<APValue> Values;
+ Values.reserve(NumElts);
+ for (unsigned I = 0; I != NumElts; ++I) {
+ TYPE_SWITCH(ElemT,
+ { Values.push_back(Ptr.elem<T>(I).toAPValue(ASTCtx)); });
+ }
+
+ R = APValue(Values.data(), MT->getNumRows(), MT->getNumColumns());
+ return true;
+ }
+
llvm_unreachable("invalid value to return");
};
diff --git a/clang/lib/AST/ByteCode/Program.cpp b/clang/lib/AST/ByteCode/Program.cpp
index 76fec63a8920d..8876ded409415 100644
--- a/clang/lib/AST/ByteCode/Program.cpp
+++ b/clang/lib/AST/ByteCode/Program.cpp
@@ -474,5 +474,16 @@ Descriptor *Program::createDescriptor(const DeclTy &D, const Type *Ty,
IsTemporary, IsMutable);
}
+ // Same with constant matrix types.
+ if (const auto *MT = Ty->getAs<ConstantMatrixType>()) {
+ OptPrimType ElemTy = Ctx.classify(MT->getElementType());
+ if (!ElemTy)
+ return nullptr;
+
+ return allocateDescriptor(D, *ElemTy, MDSize,
+ MT->getNumElementsFlattened(), IsConst,
+ IsTemporary, IsMutable);
+ }
+
return nullptr;
}
diff --git a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
index 64220980d9edc..608af75bff4bb 100644
--- a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
@@ -1,5 +1,7 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=column-major -verify %s
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=row-major -verify %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=column-major -fexperimental-new-constant-interpreter -verify %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=row-major -fexperimental-new-constant-interpreter -verify %s
// expected-no-diagnostics
``````````
</details>
https://github.com/llvm/llvm-project/pull/183424
More information about the llvm-branch-commits
mailing list