[llvm-branch-commits] [clang] [clang][bytecode][HLSL][Matrix] Support `ConstantMatrixType` and more HLSL casts in the new constant interpreter for basic matrix constexpr evaluation in HLSL (PR #183424)
Deric C. via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Feb 26 09:04:57 PST 2026
https://github.com/Icohedron updated https://github.com/llvm/llvm-project/pull/183424
>From c33ea3969b9f527d18a37932751c625d9b3f71d1 Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Wed, 25 Feb 2026 16:13:16 -0800
Subject: [PATCH 1/7] Support ConstantMatrixTypes and HLSL casts to bytecode
constexpr evaluator
This commit adds support for ConstantMatrixType and the HLSL casts
CK_HLSLArrayRValue, CK_HLSLMatrixTruncation, CK_HLSLAggregateSplatCast,
and CK_HLSLElementwiseCast to the bytecode constexpr evaluator.
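The implementations of CK_HLSLAggregateSplatCast and
CK_HLSLElementwiseCast are incomplete: neither supports struct types yet,
the splat cast does not yet support array types, and the elementwise cast
does not yet support nested arrays or arrays of composite types.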
Assisted-by: claude-opus-4.6
---
clang/lib/AST/ByteCode/Compiler.cpp | 207 ++++++++++++++++++
clang/lib/AST/ByteCode/Compiler.h | 5 +
clang/lib/AST/ByteCode/Pointer.cpp | 18 ++
clang/lib/AST/ByteCode/Program.cpp | 11 +
.../BuiltinMatrix/MatrixConstantExpr.hlsl | 2 +
5 files changed, 243 insertions(+)
diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 15e65a4d96581..87982e67dcb51 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -811,6 +811,180 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
case CK_LValueBitCast:
return this->emitInvalidCast(CastKind::ReinterpretLike, /*Fatal=*/true, CE);
+ case CK_HLSLArrayRValue: {
+ // Non-decaying array rvalue cast - creates an rvalue copy of an lvalue
+ // array, similar to LValueToRValue for composite types.
+ if (!Initializing) {
+ UnsignedOrNone LocalIndex = allocateLocal(CE);
+ if (!LocalIndex)
+ return false;
+ if (!this->emitGetPtrLocal(*LocalIndex, CE))
+ return false;
+ }
+ if (!this->visit(SubExpr))
+ return false;
+ return this->emitMemcpy(CE);
+ }
+
+ case CK_HLSLMatrixTruncation: {
+ assert(SubExpr->getType()->isConstantMatrixType());
+ if (OptPrimType ResultT = classify(CE)) {
+ assert(!DiscardResult);
+ // Result must be either a float or integer. Take the first element.
+ if (!this->visit(SubExpr))
+ return false;
+ return this->emitArrayElemPop(*ResultT, 0, CE);
+ }
+ // Otherwise, this truncates to a constant matrix type.
+ assert(CE->getType()->isConstantMatrixType());
+
+ if (!Initializing) {
+ UnsignedOrNone LocalIndex = allocateTemporary(CE);
+ if (!LocalIndex)
+ return false;
+ if (!this->emitGetPtrLocal(*LocalIndex, CE))
+ return false;
+ }
+ unsigned ToSize =
+ CE->getType()->getAs<ConstantMatrixType>()->getNumElementsFlattened();
+ if (!this->visit(SubExpr))
+ return false;
+ return this->emitCopyArray(classifyMatrixElementType(SubExpr->getType()), 0,
+ 0, ToSize, CE);
+ }
+
+ case CK_HLSLAggregateSplatCast: {
+ // Aggregate splat cast: convert a scalar value to one of an aggregate type.
+ // TODO: Aggregate splat to struct and array types
+ assert(canClassify(SubExpr->getType()));
+
+ unsigned NumElts;
+ PrimType DestElemT;
+ QualType DestElemType;
+ if (const auto *VT = CE->getType()->getAs<VectorType>()) {
+ NumElts = VT->getNumElements();
+ DestElemType = VT->getElementType();
+ } else if (const auto *MT =
+ CE->getType()->getAs<ConstantMatrixType>()) {
+ NumElts = MT->getNumElementsFlattened();
+ DestElemType = MT->getElementType();
+ } else {
+ return false;
+ }
+ DestElemT = classifyPrim(DestElemType);
+
+ if (!Initializing) {
+ UnsignedOrNone LocalIndex = allocateLocal(CE);
+ if (!LocalIndex)
+ return false;
+ if (!this->emitGetPtrLocal(*LocalIndex, CE))
+ return false;
+ }
+
+ PrimType SrcElemT = classifyPrim(SubExpr->getType());
+ unsigned SrcOffset =
+ allocateLocalPrimitive(SubExpr, DestElemT, /*IsConst=*/true);
+
+ if (!this->visit(SubExpr))
+ return false;
+ if (classifyPrim(SubExpr) == PT_Ptr && !this->emitLoadPop(SrcElemT, CE))
+ return false;
+ if (SrcElemT != DestElemT) {
+ if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE))
+ return false;
+ }
+ if (!this->emitSetLocal(DestElemT, SrcOffset, CE))
+ return false;
+
+ for (unsigned I = 0; I != NumElts; ++I) {
+ if (!this->emitGetLocal(DestElemT, SrcOffset, CE))
+ return false;
+ if (!this->emitInitElem(DestElemT, I, CE))
+ return false;
+ }
+ return true;
+ }
+
+ case CK_HLSLElementwiseCast: {
+ // Elementwise cast: flatten source elements of one aggregate type and store
+ // to a destination aggregate type of the same or fewer number of elements.
+ // TODO: Elementwise cast to structs, nested arrays, and arrays of composite
+ // types
+ QualType SrcType = SubExpr->getType();
+ QualType DestType = CE->getType();
+
+ unsigned SrcNumElts;
+ PrimType SrcElemT;
+ if (const auto *VT = SrcType->getAs<VectorType>()) {
+ SrcNumElts = VT->getNumElements();
+ SrcElemT = classifyPrim(VT->getElementType());
+ } else if (const auto *MT = SrcType->getAs<ConstantMatrixType>()) {
+ SrcNumElts = MT->getNumElementsFlattened();
+ SrcElemT = classifyPrim(MT->getElementType());
+ } else if (const auto *AT = SrcType->getAsArrayTypeUnsafe()) {
+ if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) {
+ SrcNumElts = CAT->getZExtSize();
+ SrcElemT = classifyPrim(CAT->getElementType());
+ } else {
+ return false;
+ }
+ } else {
+ return false;
+ }
+
+ unsigned DestNumElts;
+ PrimType DestElemT;
+ QualType DestElemType;
+ if (const auto *VT = DestType->getAs<VectorType>()) {
+ DestNumElts = VT->getNumElements();
+ DestElemType = VT->getElementType();
+ } else if (const auto *MT = DestType->getAs<ConstantMatrixType>()) {
+ DestNumElts = MT->getNumElementsFlattened();
+ DestElemType = MT->getElementType();
+ } else if (const auto *AT = DestType->getAsArrayTypeUnsafe()) {
+ if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) {
+ DestNumElts = CAT->getZExtSize();
+ DestElemType = CAT->getElementType();
+ } else {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ DestElemT = classifyPrim(DestElemType);
+
+ if (!Initializing) {
+ UnsignedOrNone LocalIndex =
+ classify(DestType) ? allocateLocal(CE) : allocateTemporary(CE);
+ if (!LocalIndex)
+ return false;
+ if (!this->emitGetPtrLocal(*LocalIndex, CE))
+ return false;
+ }
+
+ unsigned SrcOffset =
+ allocateLocalPrimitive(SubExpr, PT_Ptr, /*IsConst=*/true);
+ if (!this->visit(SubExpr))
+ return false;
+ if (!this->emitSetLocal(PT_Ptr, SrcOffset, CE))
+ return false;
+
+ unsigned NumElts = std::min(SrcNumElts, DestNumElts);
+ for (unsigned I = 0; I != NumElts; ++I) {
+ if (!this->emitGetLocal(PT_Ptr, SrcOffset, CE))
+ return false;
+ if (!this->emitArrayElemPop(SrcElemT, I, CE))
+ return false;
+ if (SrcElemT != DestElemT) {
+ if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE))
+ return false;
+ }
+ if (!this->emitInitElem(DestElemT, I, CE))
+ return false;
+ }
+ return true;
+ }
+
default:
return this->emitInvalid(CE);
}
@@ -1813,6 +1987,20 @@ bool Compiler<Emitter>::VisitImplicitValueInitExpr(
return true;
}
+ if (const auto *MT = E->getType()->getAs<ConstantMatrixType>()) {
+ unsigned NumElts = MT->getNumElementsFlattened();
+ QualType ElemQT = MT->getElementType();
+ PrimType ElemT = classifyPrim(ElemQT);
+
+ for (unsigned I = 0; I < NumElts; ++I) {
+ if (!this->visitZeroInitializer(ElemT, ElemQT, E))
+ return false;
+ if (!this->emitInitElem(ElemT, I, E))
+ return false;
+ }
+ return true;
+ }
+
return false;
}
@@ -2129,6 +2317,25 @@ bool Compiler<Emitter>::visitInitList(ArrayRef<const Expr *> Inits,
return true;
}
+ if (const auto *MT = QT->getAs<ConstantMatrixType>()) {
+ unsigned NumElts = MT->getNumElementsFlattened();
+ assert(Inits.size() == NumElts);
+
+ QualType ElemQT = MT->getElementType();
+ PrimType ElemT = classifyPrim(ElemQT);
+
+ // InitListExpr elements are in column-major order.
+ // Store in row-major order to match APValue convention.
+ for (unsigned I = 0; I < NumElts; ++I) {
+ if (!this->visit(Inits[I]))
+ return false;
+ if (!this->emitInitElem(ElemT,
+ MT->mapColumnMajorToRowMajorFlattenedIndex(I), E))
+ return false;
+ }
+ return true;
+ }
+
return false;
}
diff --git a/clang/lib/AST/ByteCode/Compiler.h b/clang/lib/AST/ByteCode/Compiler.h
index 1bd15c3d79563..74ded47e88792 100644
--- a/clang/lib/AST/ByteCode/Compiler.h
+++ b/clang/lib/AST/ByteCode/Compiler.h
@@ -406,6 +406,11 @@ class Compiler : public ConstStmtVisitor<Compiler<Emitter>, bool>,
return *this->classify(T->getAs<VectorType>()->getElementType());
}
+ PrimType classifyMatrixElementType(QualType T) const {
+ assert(T->isMatrixType());
+ return *this->classify(T->getAs<MatrixType>()->getElementType());
+ }
+
bool emitComplexReal(const Expr *SubExpr);
bool emitComplexBoolCast(const Expr *E);
bool emitComplexComparison(const Expr *LHS, const Expr *RHS,
diff --git a/clang/lib/AST/ByteCode/Pointer.cpp b/clang/lib/AST/ByteCode/Pointer.cpp
index e237013f4199c..a569a4221cf2f 100644
--- a/clang/lib/AST/ByteCode/Pointer.cpp
+++ b/clang/lib/AST/ByteCode/Pointer.cpp
@@ -934,6 +934,24 @@ std::optional<APValue> Pointer::toRValue(const Context &Ctx,
return true;
}
+ // Constant Matrix types.
+ if (const auto *MT = Ty->getAs<ConstantMatrixType>()) {
+ assert(Ptr.getFieldDesc()->isPrimitiveArray());
+ QualType ElemTy = MT->getElementType();
+ PrimType ElemT = *Ctx.classify(ElemTy);
+ unsigned NumElts = MT->getNumElementsFlattened();
+
+ SmallVector<APValue> Values;
+ Values.reserve(NumElts);
+ for (unsigned I = 0; I != NumElts; ++I) {
+ TYPE_SWITCH(ElemT,
+ { Values.push_back(Ptr.elem<T>(I).toAPValue(ASTCtx)); });
+ }
+
+ R = APValue(Values.data(), MT->getNumRows(), MT->getNumColumns());
+ return true;
+ }
+
llvm_unreachable("invalid value to return");
};
diff --git a/clang/lib/AST/ByteCode/Program.cpp b/clang/lib/AST/ByteCode/Program.cpp
index 76fec63a8920d..8876ded409415 100644
--- a/clang/lib/AST/ByteCode/Program.cpp
+++ b/clang/lib/AST/ByteCode/Program.cpp
@@ -474,5 +474,16 @@ Descriptor *Program::createDescriptor(const DeclTy &D, const Type *Ty,
IsTemporary, IsMutable);
}
+ // Same with constant matrix types.
+ if (const auto *MT = Ty->getAs<ConstantMatrixType>()) {
+ OptPrimType ElemTy = Ctx.classify(MT->getElementType());
+ if (!ElemTy)
+ return nullptr;
+
+ return allocateDescriptor(D, *ElemTy, MDSize,
+ MT->getNumElementsFlattened(), IsConst,
+ IsTemporary, IsMutable);
+ }
+
return nullptr;
}
diff --git a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
index 64220980d9edc..608af75bff4bb 100644
--- a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
@@ -1,5 +1,7 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=column-major -verify %s
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=row-major -verify %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=column-major -fexperimental-new-constant-interpreter -verify %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=row-major -fexperimental-new-constant-interpreter -verify %s
// expected-no-diagnostics
>From cd61ecb70cf65b5df44154c418179db137399645 Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Wed, 25 Feb 2026 16:59:02 -0800
Subject: [PATCH 2/7] Apply clang-format
---
clang/lib/AST/ByteCode/Compiler.cpp | 3 +--
clang/lib/AST/ByteCode/Program.cpp | 5 ++---
2 files changed, 3 insertions(+), 5 deletions(-)
diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 87982e67dcb51..6f5ed3b7dd107 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -864,8 +864,7 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
if (const auto *VT = CE->getType()->getAs<VectorType>()) {
NumElts = VT->getNumElements();
DestElemType = VT->getElementType();
- } else if (const auto *MT =
- CE->getType()->getAs<ConstantMatrixType>()) {
+ } else if (const auto *MT = CE->getType()->getAs<ConstantMatrixType>()) {
NumElts = MT->getNumElementsFlattened();
DestElemType = MT->getElementType();
} else {
diff --git a/clang/lib/AST/ByteCode/Program.cpp b/clang/lib/AST/ByteCode/Program.cpp
index 8876ded409415..a2cba3da27675 100644
--- a/clang/lib/AST/ByteCode/Program.cpp
+++ b/clang/lib/AST/ByteCode/Program.cpp
@@ -480,9 +480,8 @@ Descriptor *Program::createDescriptor(const DeclTy &D, const Type *Ty,
if (!ElemTy)
return nullptr;
- return allocateDescriptor(D, *ElemTy, MDSize,
- MT->getNumElementsFlattened(), IsConst,
- IsTemporary, IsMutable);
+ return allocateDescriptor(D, *ElemTy, MDSize, MT->getNumElementsFlattened(),
+ IsConst, IsTemporary, IsMutable);
}
return nullptr;
>From 7bc9164b5d3a2977f55136844f7a39c9c952f1cd Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Wed, 25 Feb 2026 22:02:56 -0800
Subject: [PATCH 3/7] Rename NumElts to NumElems
---
clang/lib/AST/ByteCode/Compiler.cpp | 38 ++++++++++++++---------------
clang/lib/AST/ByteCode/Pointer.cpp | 6 ++---
2 files changed, 22 insertions(+), 22 deletions(-)
diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 6f5ed3b7dd107..32ec16570a5b6 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -858,14 +858,14 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
// TODO: Aggregate splat to struct and array types
assert(canClassify(SubExpr->getType()));
- unsigned NumElts;
+ unsigned NumElems;
PrimType DestElemT;
QualType DestElemType;
if (const auto *VT = CE->getType()->getAs<VectorType>()) {
- NumElts = VT->getNumElements();
+ NumElems = VT->getNumElements();
DestElemType = VT->getElementType();
} else if (const auto *MT = CE->getType()->getAs<ConstantMatrixType>()) {
- NumElts = MT->getNumElementsFlattened();
+ NumElems = MT->getNumElementsFlattened();
DestElemType = MT->getElementType();
} else {
return false;
@@ -895,7 +895,7 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
if (!this->emitSetLocal(DestElemT, SrcOffset, CE))
return false;
- for (unsigned I = 0; I != NumElts; ++I) {
+ for (unsigned I = 0; I != NumElems; ++I) {
if (!this->emitGetLocal(DestElemT, SrcOffset, CE))
return false;
if (!this->emitInitElem(DestElemT, I, CE))
@@ -912,17 +912,17 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
QualType SrcType = SubExpr->getType();
QualType DestType = CE->getType();
- unsigned SrcNumElts;
+ unsigned SrcNumElems;
PrimType SrcElemT;
if (const auto *VT = SrcType->getAs<VectorType>()) {
- SrcNumElts = VT->getNumElements();
+ SrcNumElems = VT->getNumElements();
SrcElemT = classifyPrim(VT->getElementType());
} else if (const auto *MT = SrcType->getAs<ConstantMatrixType>()) {
- SrcNumElts = MT->getNumElementsFlattened();
+ SrcNumElems = MT->getNumElementsFlattened();
SrcElemT = classifyPrim(MT->getElementType());
} else if (const auto *AT = SrcType->getAsArrayTypeUnsafe()) {
if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) {
- SrcNumElts = CAT->getZExtSize();
+ SrcNumElems = CAT->getZExtSize();
SrcElemT = classifyPrim(CAT->getElementType());
} else {
return false;
@@ -931,18 +931,18 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
return false;
}
- unsigned DestNumElts;
+ unsigned DestNumElems;
PrimType DestElemT;
QualType DestElemType;
if (const auto *VT = DestType->getAs<VectorType>()) {
- DestNumElts = VT->getNumElements();
+ DestNumElems = VT->getNumElements();
DestElemType = VT->getElementType();
} else if (const auto *MT = DestType->getAs<ConstantMatrixType>()) {
- DestNumElts = MT->getNumElementsFlattened();
+ DestNumElems = MT->getNumElementsFlattened();
DestElemType = MT->getElementType();
} else if (const auto *AT = DestType->getAsArrayTypeUnsafe()) {
if (const auto *CAT = dyn_cast<ConstantArrayType>(AT)) {
- DestNumElts = CAT->getZExtSize();
+ DestNumElems = CAT->getZExtSize();
DestElemType = CAT->getElementType();
} else {
return false;
@@ -968,8 +968,8 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
if (!this->emitSetLocal(PT_Ptr, SrcOffset, CE))
return false;
- unsigned NumElts = std::min(SrcNumElts, DestNumElts);
- for (unsigned I = 0; I != NumElts; ++I) {
+ unsigned NumElems = std::min(SrcNumElems, DestNumElems);
+ for (unsigned I = 0; I != NumElems; ++I) {
if (!this->emitGetLocal(PT_Ptr, SrcOffset, CE))
return false;
if (!this->emitArrayElemPop(SrcElemT, I, CE))
@@ -1987,11 +1987,11 @@ bool Compiler<Emitter>::VisitImplicitValueInitExpr(
}
if (const auto *MT = E->getType()->getAs<ConstantMatrixType>()) {
- unsigned NumElts = MT->getNumElementsFlattened();
+ unsigned NumElems = MT->getNumElementsFlattened();
QualType ElemQT = MT->getElementType();
PrimType ElemT = classifyPrim(ElemQT);
- for (unsigned I = 0; I < NumElts; ++I) {
+ for (unsigned I = 0; I < NumElems; ++I) {
if (!this->visitZeroInitializer(ElemT, ElemQT, E))
return false;
if (!this->emitInitElem(ElemT, I, E))
@@ -2317,15 +2317,15 @@ bool Compiler<Emitter>::visitInitList(ArrayRef<const Expr *> Inits,
}
if (const auto *MT = QT->getAs<ConstantMatrixType>()) {
- unsigned NumElts = MT->getNumElementsFlattened();
- assert(Inits.size() == NumElts);
+ unsigned NumElems = MT->getNumElementsFlattened();
+ assert(Inits.size() == NumElems);
QualType ElemQT = MT->getElementType();
PrimType ElemT = classifyPrim(ElemQT);
// InitListExpr elements are in column-major order.
// Store in row-major order to match APValue convention.
- for (unsigned I = 0; I < NumElts; ++I) {
+ for (unsigned I = 0; I < NumElems; ++I) {
if (!this->visit(Inits[I]))
return false;
if (!this->emitInitElem(ElemT,
diff --git a/clang/lib/AST/ByteCode/Pointer.cpp b/clang/lib/AST/ByteCode/Pointer.cpp
index a569a4221cf2f..f4352e7edf5f8 100644
--- a/clang/lib/AST/ByteCode/Pointer.cpp
+++ b/clang/lib/AST/ByteCode/Pointer.cpp
@@ -939,11 +939,11 @@ std::optional<APValue> Pointer::toRValue(const Context &Ctx,
assert(Ptr.getFieldDesc()->isPrimitiveArray());
QualType ElemTy = MT->getElementType();
PrimType ElemT = *Ctx.classify(ElemTy);
- unsigned NumElts = MT->getNumElementsFlattened();
+ unsigned NumElems = MT->getNumElementsFlattened();
SmallVector<APValue> Values;
- Values.reserve(NumElts);
- for (unsigned I = 0; I != NumElts; ++I) {
+ Values.reserve(NumElems);
+ for (unsigned I = 0; I != NumElems; ++I) {
TYPE_SWITCH(ElemT,
{ Values.push_back(Ptr.elem<T>(I).toAPValue(ASTCtx)); });
}
>From 4bd726d2c2dd898b8403103cb4e0ab962f8b9de4 Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Wed, 25 Feb 2026 22:10:49 -0800
Subject: [PATCH 4/7] Remove unnecessary SubExpr ptr check and load
This code was copied over from the CK_VectorSplat case but is not needed
for CK_HLSLAggregateSplatCast. The `classifyPrim(SubExpr) == PT_Ptr &&
!this->emitLoadPop(SrcElemT, CE)` check is not necessary because an
lvalue-to-rvalue conversion is always inserted on the operand of an
aggregate splat cast, so the SubExpr will never be a pointer.
See https://github.com/llvm/llvm-project/blob/143664fcd3df825befdb9586151d53aefef3d7d0/clang/lib/Sema/SemaCast.cpp#L2939-L2940
for confirmation that this is the case.
---
clang/lib/AST/ByteCode/Compiler.cpp | 2 --
1 file changed, 2 deletions(-)
diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 32ec16570a5b6..9cdaa3cc31288 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -886,8 +886,6 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
if (!this->visit(SubExpr))
return false;
- if (classifyPrim(SubExpr) == PT_Ptr && !this->emitLoadPop(SrcElemT, CE))
- return false;
if (SrcElemT != DestElemT) {
if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, CE))
return false;
>From 952f2cb56d629a6f07ddce29a0b92ac5bcd3dc04 Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Wed, 25 Feb 2026 23:16:09 -0800
Subject: [PATCH 5/7] Support scalar DestType for HLSLElementwiseCast
Also reword the comments for HLSLElementwiseCast and
HLSLAggregateSplatCast to make them clearer.
Assisted-by: claude-opus-4.6
---
clang/lib/AST/ByteCode/Compiler.cpp | 22 +++++++++++++++----
.../BuiltinMatrix/MatrixConstantExpr.hlsl | 2 ++
2 files changed, 20 insertions(+), 4 deletions(-)
diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 9cdaa3cc31288..08c39538107df 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -854,7 +854,9 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
}
case CK_HLSLAggregateSplatCast: {
- // Aggregate splat cast: convert a scalar value to one of an aggregate type.
+ // Aggregate splat cast: convert a scalar value to one of an aggregate type,
+ // inserting casts when necessary to convert the scalar to the aggregate's
+ // element type(s).
// TODO: Aggregate splat to struct and array types
assert(canClassify(SubExpr->getType()));
@@ -904,7 +906,8 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
case CK_HLSLElementwiseCast: {
// Elementwise cast: flatten source elements of one aggregate type and store
- // to a destination aggregate type of the same or fewer number of elements.
+ // to a destination scalar or aggregate type of the same or fewer number of
+ // elements, while inserting casts as necessary.
// TODO: Elementwise cast to structs, nested arrays, and arrays of composite
// types
QualType SrcType = SubExpr->getType();
@@ -945,14 +948,25 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
} else {
return false;
}
+ } else if (classify(DestType)) {
+ // Scalar destination: extract element 0 and cast.
+ PrimType DestT = classifyPrim(DestType);
+ if (!this->visit(SubExpr))
+ return false;
+ if (!this->emitArrayElemPop(SrcElemT, 0, CE))
+ return false;
+ if (SrcElemT != DestT) {
+ if (!this->emitPrimCast(SrcElemT, DestT, DestType, CE))
+ return false;
+ }
+ return true;
} else {
return false;
}
DestElemT = classifyPrim(DestElemType);
if (!Initializing) {
- UnsignedOrNone LocalIndex =
- classify(DestType) ? allocateLocal(CE) : allocateTemporary(CE);
+ UnsignedOrNone LocalIndex = allocateTemporary(CE);
if (!LocalIndex)
return false;
if (!this->emitGetPtrLocal(*LocalIndex, CE))
diff --git a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
index 608af75bff4bb..2c55e1a0ee4b3 100644
--- a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
@@ -45,6 +45,8 @@ export void fn() {
_Static_assert(FA4[1] == 2.5, "Woo!");
_Static_assert(FA4[2] == 3.5, "Woo!");
_Static_assert(FA4[3] == 4.5, "Woo!");
+ constexpr float F = (float)FA4;
+ _Static_assert(F == 1.5, "Woo!");
}
// Array cast to matrix to vector
>From 974eec6611da919c1489f9d498e9be5913723660 Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Wed, 25 Feb 2026 23:27:32 -0800
Subject: [PATCH 6/7] Change < to != in loop guards
---
clang/lib/AST/ByteCode/Compiler.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 08c39538107df..914da43a3d714 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -2003,7 +2003,7 @@ bool Compiler<Emitter>::VisitImplicitValueInitExpr(
QualType ElemQT = MT->getElementType();
PrimType ElemT = classifyPrim(ElemQT);
- for (unsigned I = 0; I < NumElems; ++I) {
+ for (unsigned I = 0; I != NumElems; ++I) {
if (!this->visitZeroInitializer(ElemT, ElemQT, E))
return false;
if (!this->emitInitElem(ElemT, I, E))
@@ -2337,7 +2337,7 @@ bool Compiler<Emitter>::visitInitList(ArrayRef<const Expr *> Inits,
// InitListExpr elements are in column-major order.
// Store in row-major order to match APValue convention.
- for (unsigned I = 0; I < NumElems; ++I) {
+ for (unsigned I = 0; I != NumElems; ++I) {
if (!this->visit(Inits[I]))
return false;
if (!this->emitInitElem(ElemT,
>From eb43998a7e2e081ced9c23fcf45307d26c74e052 Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Thu, 26 Feb 2026 09:04:28 -0800
Subject: [PATCH 7/7] Replace PrimType DestT with OptPrimType DestT
---
clang/lib/AST/ByteCode/Compiler.cpp | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 914da43a3d714..2f7026b961bbd 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -948,15 +948,14 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
} else {
return false;
}
- } else if (classify(DestType)) {
+ } else if (OptPrimType DestT = classify(DestType)) {
// Scalar destination: extract element 0 and cast.
- PrimType DestT = classifyPrim(DestType);
if (!this->visit(SubExpr))
return false;
if (!this->emitArrayElemPop(SrcElemT, 0, CE))
return false;
- if (SrcElemT != DestT) {
- if (!this->emitPrimCast(SrcElemT, DestT, DestType, CE))
+ if (SrcElemT != *DestT) {
+ if (!this->emitPrimCast(SrcElemT, *DestT, DestType, CE))
return false;
}
return true;
More information about the llvm-branch-commits
mailing list