[clang] [HLSL][Matrix] Add APValue and ConstExpr evaluator support for matrices (PR #178762)
Deric C. via cfe-commits
cfe-commits at lists.llvm.org
Thu Jan 29 14:51:21 PST 2026
https://github.com/Icohedron updated https://github.com/llvm/llvm-project/pull/178762
>From 6fa71f2c548da5174ee2fff267e0afaea786d828 Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Thu, 29 Jan 2026 12:20:47 -0800
Subject: [PATCH 1/4] Add matrix APValue and ConstExpr evaluator for HLSL
Assisted-by: claude-opus-4.5
---
clang/include/clang/AST/APValue.h | 72 ++++++-
clang/lib/AST/APValue.cpp | 33 ++++
clang/lib/AST/ExprConstant.cpp | 178 +++++++++++++++++-
clang/lib/AST/Type.cpp | 4 +
clang/lib/CodeGen/CGExprConstant.cpp | 30 +++
.../BuiltinMatrix/MatrixConstantExpr.hlsl | 118 ++++++++++++
6 files changed, 427 insertions(+), 8 deletions(-)
create mode 100644 clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
diff --git a/clang/include/clang/AST/APValue.h b/clang/include/clang/AST/APValue.h
index 8a2d6d434792a..61a8a6c9d1d25 100644
--- a/clang/include/clang/AST/APValue.h
+++ b/clang/include/clang/AST/APValue.h
@@ -136,6 +136,7 @@ class APValue {
ComplexFloat,
LValue,
Vector,
+ Matrix,
Array,
Struct,
Union,
@@ -275,6 +276,15 @@ class APValue {
Vec &operator=(const Vec &) = delete;
~Vec() { delete[] Elts; }
};
+ struct Mat {
+ APValue *Elts = nullptr;
+ unsigned NumRows = 0;
+ unsigned NumCols = 0;
+ Mat() = default;
+ Mat(const Mat &) = delete;
+ Mat &operator=(const Mat &) = delete;
+ ~Mat() { delete[] Elts; }
+ };
struct Arr {
APValue *Elts;
unsigned NumElts, ArrSize;
@@ -308,8 +318,9 @@ class APValue {
// We ensure elsewhere that Data is big enough for LV and MemberPointerData.
typedef llvm::AlignedCharArrayUnion<void *, APSInt, APFloat, ComplexAPSInt,
- ComplexAPFloat, Vec, Arr, StructData,
- UnionData, AddrLabelDiffData> DataType;
+ ComplexAPFloat, Vec, Mat, Arr, StructData,
+ UnionData, AddrLabelDiffData>
+ DataType;
static const size_t DataSize = sizeof(DataType);
DataType Data;
@@ -341,6 +352,13 @@ class APValue {
: Kind(None), AllowConstexprUnknown(false) {
MakeVector(); setVector(E, N);
}
+ /// Creates a matrix APValue with given dimensions. The elements
+ /// are read from \p E and assumed to be in column-major order.
+ explicit APValue(const APValue *E, unsigned NumRows, unsigned NumCols)
+ : Kind(None), AllowConstexprUnknown(false) {
+ MakeMatrix();
+ setMatrix(E, NumRows, NumCols);
+ }
/// Creates an integer complex APValue with the given real and imaginary
/// values.
APValue(APSInt R, APSInt I) : Kind(None), AllowConstexprUnknown(false) {
@@ -471,6 +489,7 @@ class APValue {
bool isComplexFloat() const { return Kind == ComplexFloat; }
bool isLValue() const { return Kind == LValue; }
bool isVector() const { return Kind == Vector; }
+ bool isMatrix() const { return Kind == Matrix; }
bool isArray() const { return Kind == Array; }
bool isStruct() const { return Kind == Struct; }
bool isUnion() const { return Kind == Union; }
@@ -573,6 +592,37 @@ class APValue {
return ((const Vec *)(const void *)&Data)->NumElts;
}
+ unsigned getMatrixNumRows() const {
+ assert(isMatrix() && "Invalid accessor");
+ return ((const Mat *)(const void *)&Data)->NumRows;
+ }
+ unsigned getMatrixNumCols() const {
+ assert(isMatrix() && "Invalid accessor");
+ return ((const Mat *)(const void *)&Data)->NumCols;
+ }
+ unsigned getMatrixNumElements() const {
+ return getMatrixNumRows() * getMatrixNumCols();
+ }
+ APValue &getMatrixElt(unsigned Idx) {
+ assert(isMatrix() && "Invalid accessor");
+ assert(Idx < getMatrixNumElements() && "Index out of range");
+ return ((Mat *)(char *)&Data)->Elts[Idx];
+ }
+ const APValue &getMatrixElt(unsigned Idx) const {
+ return const_cast<APValue *>(this)->getMatrixElt(Idx);
+ }
+ APValue &getMatrixElt(unsigned Row, unsigned Col) {
+ assert(isMatrix() && "Invalid accessor");
+ assert(Row < getMatrixNumRows() && "Row index out of range");
+ assert(Col < getMatrixNumCols() && "Column index out of range");
+ // Matrix elements are stored in column-major order.
+ unsigned I = Col * getMatrixNumRows() + Row;
+ return ((Mat *)(char *)&Data)->Elts[I];
+ }
+ const APValue &getMatrixElt(unsigned Row, unsigned Col) const {
+ return const_cast<APValue *>(this)->getMatrixElt(Row, Col);
+ }
+
APValue &getArrayInitializedElt(unsigned I) {
assert(isArray() && "Invalid accessor");
assert(I < getArrayInitializedElts() && "Index out of range");
@@ -668,6 +718,11 @@ class APValue {
for (unsigned i = 0; i != N; ++i)
InternalElts[i] = E[i];
}
+ void setMatrix(const APValue *E, unsigned NumRows, unsigned NumCols) {
+ MutableArrayRef<APValue> InternalElts = setMatrixUninit(NumRows, NumCols);
+ for (unsigned i = 0; i != NumRows * NumCols; ++i)
+ InternalElts[i] = E[i];
+ }
void setComplexInt(APSInt R, APSInt I) {
assert(R.getBitWidth() == I.getBitWidth() &&
"Invalid complex int (type mismatch).");
@@ -716,6 +771,11 @@ class APValue {
new ((void *)(char *)&Data) Vec();
Kind = Vector;
}
+ void MakeMatrix() {
+ assert(isAbsent() && "Bad state change");
+ new ((void *)(char *)&Data) Mat();
+ Kind = Matrix;
+ }
void MakeComplexInt() {
assert(isAbsent() && "Bad state change");
new ((void *)(char *)&Data) ComplexAPSInt();
@@ -757,6 +817,14 @@ class APValue {
V->NumElts = N;
return {V->Elts, V->NumElts};
}
+ MutableArrayRef<APValue> setMatrixUninit(unsigned NumRows, unsigned NumCols) {
+ assert(isMatrix() && "Invalid accessor");
+ Mat *M = ((Mat *)(char *)&Data);
+ M->Elts = new APValue[NumRows * NumCols];
+ M->NumRows = NumRows;
+ M->NumCols = NumCols;
+ return {M->Elts, NumRows * NumCols};
+ }
MutableArrayRef<LValuePathEntry>
setLValueUninit(LValueBase B, const CharUnits &O, unsigned Size,
bool OnePastTheEnd, bool IsNullPtr);
diff --git a/clang/lib/AST/APValue.cpp b/clang/lib/AST/APValue.cpp
index 2e1c8eb3726cf..d1853dfd6a7b4 100644
--- a/clang/lib/AST/APValue.cpp
+++ b/clang/lib/AST/APValue.cpp
@@ -333,6 +333,11 @@ APValue::APValue(const APValue &RHS)
setVector(((const Vec *)(const char *)&RHS.Data)->Elts,
RHS.getVectorLength());
break;
+ case Matrix:
+ MakeMatrix();
+ setMatrix(((const Mat *)(const char *)&RHS.Data)->Elts,
+ RHS.getMatrixNumRows(), RHS.getMatrixNumCols());
+ break;
case ComplexInt:
MakeComplexInt();
setComplexInt(RHS.getComplexIntReal(), RHS.getComplexIntImag());
@@ -414,6 +419,8 @@ void APValue::DestroyDataAndMakeUninit() {
((APFixedPoint *)(char *)&Data)->~APFixedPoint();
else if (Kind == Vector)
((Vec *)(char *)&Data)->~Vec();
+ else if (Kind == Matrix)
+ ((Mat *)(char *)&Data)->~Mat();
else if (Kind == ComplexInt)
((ComplexAPSInt *)(char *)&Data)->~ComplexAPSInt();
else if (Kind == ComplexFloat)
@@ -444,6 +451,7 @@ bool APValue::needsCleanup() const {
case Union:
case Array:
case Vector:
+ case Matrix:
return true;
case Int:
return getInt().needsCleanup();
@@ -580,6 +588,12 @@ void APValue::Profile(llvm::FoldingSetNodeID &ID) const {
getVectorElt(I).Profile(ID);
return;
+ case Matrix:
+ for (unsigned R = 0, N = getMatrixNumRows(); R != N; ++R)
+ for (unsigned C = 0, M = getMatrixNumCols(); C != M; ++C)
+ getMatrixElt(R, C).Profile(ID);
+ return;
+
case Int:
profileIntValue(ID, getInt());
return;
@@ -747,6 +761,24 @@ void APValue::printPretty(raw_ostream &Out, const PrintingPolicy &Policy,
Out << '}';
return;
}
+ case APValue::Matrix: {
+ const auto *MT = Ty->castAs<ConstantMatrixType>();
+ QualType ElemTy = MT->getElementType();
+ Out << '{';
+ for (unsigned R = 0; R < getMatrixNumRows(); ++R) {
+ if (R != 0)
+ Out << ", ";
+ Out << '{';
+ for (unsigned C = 0; C < getMatrixNumCols(); ++C) {
+ if (C != 0)
+ Out << ", ";
+ getMatrixElt(R, C).printPretty(Out, Policy, ElemTy, Ctx);
+ }
+ Out << '}';
+ }
+ Out << '}';
+ return;
+ }
case APValue::ComplexInt:
Out << getComplexIntReal() << "+" << getComplexIntImag() << "i";
return;
@@ -1139,6 +1171,7 @@ LinkageInfo LinkageComputer::getLVForValue(const APValue &V,
case APValue::ComplexInt:
case APValue::ComplexFloat:
case APValue::Vector:
+ case APValue::Matrix:
break;
case APValue::AddrLabelDiff:
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 73768f7dd612b..d5107c05f8007 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -1886,6 +1886,7 @@ static bool EvaluateIntegerOrLValue(const Expr *E, APValue &Result,
EvalInfo &Info);
static bool EvaluateFloat(const Expr *E, APFloat &Result, EvalInfo &Info);
static bool EvaluateComplex(const Expr *E, ComplexValue &Res, EvalInfo &Info);
+static bool EvaluateMatrix(const Expr *E, APValue &Result, EvalInfo &Info);
static bool EvaluateAtomic(const Expr *E, const LValue *This, APValue &Result,
EvalInfo &Info);
static bool EvaluateAsRValue(EvalInfo &Info, const Expr *E, APValue &Result);
@@ -2717,6 +2718,7 @@ static bool HandleConversionToBool(const APValue &Val, bool &Result) {
Result = Val.getMemberPointerDecl();
return true;
case APValue::Vector:
+ case APValue::Matrix:
case APValue::Array:
case APValue::Struct:
case APValue::Union:
@@ -4043,6 +4045,12 @@ static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) {
Size += NumEl;
continue;
}
+ if (Type->isConstantMatrixType()) {
+ unsigned NumEl =
+ Type->castAs<ConstantMatrixType>()->getNumElementsFlattened();
+ Size += NumEl;
+ continue;
+ }
if (Type->isConstantArrayType()) {
QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
->getElementType();
@@ -4093,6 +4101,11 @@ static bool hlslAggSplatHelper(EvalInfo &Info, const Expr *E, APValue &SrcVal,
SrcTy = SrcTy->castAs<VectorType>()->getElementType();
SrcVal = SrcVal.getVectorElt(0);
}
+ if (SrcVal.isMatrix()) {
+ assert(SrcTy->isConstantMatrixType() && "Type mismatch.");
+ SrcTy = SrcTy->castAs<ConstantMatrixType>()->getElementType();
+ SrcVal = SrcVal.getMatrixElt(0, 0);
+ }
return true;
}
@@ -4122,6 +4135,22 @@ static bool flattenAPValue(EvalInfo &Info, const Expr *E, APValue Value,
}
continue;
}
+ if (Work.isMatrix()) {
+ assert(Type->isConstantMatrixType() && "Type mismatch.");
+ const auto *MT = Type->castAs<ConstantMatrixType>();
+ QualType ElTy = MT->getElementType();
+ // Matrix elements are flattened in row-major order.
+ for (unsigned Row = 0; Row < Work.getMatrixNumRows() && Populated < Size;
+ Row++) {
+ for (unsigned Col = 0;
+ Col < Work.getMatrixNumCols() && Populated < Size; Col++) {
+ Elements.push_back(Work.getMatrixElt(Row, Col));
+ Types.push_back(ElTy);
+ Populated++;
+ }
+ }
+ continue;
+ }
if (Work.isArray()) {
assert(Type->isConstantArrayType() && "Type mismatch.");
QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
@@ -7844,6 +7873,7 @@ class APValueToBufferConverter {
case APValue::FixedPoint:
// FIXME: We should support these.
+ case APValue::Matrix:
case APValue::Union:
case APValue::MemberPointer:
case APValue::AddrLabelDiff: {
@@ -11769,8 +11799,17 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) {
return Success(Elements, E);
}
case CK_HLSLMatrixTruncation: {
- // TODO: See #168935. Add matrix truncation support to expr constant.
- return Error(E);
+ // Matrix truncation occurs in row-major order.
+ APValue Val;
+ if (!EvaluateMatrix(SE, Val, Info))
+ return Error(E);
+ SmallVector<APValue, 16> Elements;
+ for (unsigned Row = 0;
+ Row < Val.getMatrixNumRows() && Elements.size() < NElts; Row++)
+ for (unsigned Col = 0;
+ Col < Val.getMatrixNumCols() && Elements.size() < NElts; Col++)
+ Elements.push_back(Val.getMatrixElt(Row, Col));
+ return Success(Elements, E);
}
case CK_HLSLAggregateSplatCast: {
APValue Val;
@@ -14604,6 +14643,126 @@ bool VectorExprEvaluator::VisitShuffleVectorExpr(const ShuffleVectorExpr *E) {
return Success(APValue(ResultElements.data(), ResultElements.size()), E);
}
+//===----------------------------------------------------------------------===//
+// Matrix Evaluation
+//===----------------------------------------------------------------------===//
+
+namespace {
+class MatrixExprEvaluator : public ExprEvaluatorBase<MatrixExprEvaluator> {
+ APValue &Result;
+
+public:
+ MatrixExprEvaluator(EvalInfo &Info, APValue &Result)
+ : ExprEvaluatorBaseTy(Info), Result(Result) {}
+
+ bool Success(ArrayRef<APValue> M, unsigned NumRows, unsigned NumCols,
+ const Expr *E) {
+ assert(
+ M.size() ==
+ E->getType()->castAs<ConstantMatrixType>()->getNumElementsFlattened());
+ assert(M.size() == NumRows * NumCols);
+ // FIXME: remove this APValue copy.
+ Result = APValue(M.data(), NumRows, NumCols);
+ return true;
+ }
+ bool Success(const APValue &M, const Expr *E) {
+ assert(M.isMatrix() && "expected matrix");
+ Result = M;
+ return true;
+ }
+
+ bool VisitCastExpr(const CastExpr *E);
+ bool VisitInitListExpr(const InitListExpr *E);
+};
+} // end anonymous namespace
+
+static bool EvaluateMatrix(const Expr *E, APValue &Result, EvalInfo &Info) {
+ assert(E->isPRValue() && E->getType()->isConstantMatrixType() &&
+ "not a matrix prvalue");
+ return MatrixExprEvaluator(Info, Result).Visit(E);
+}
+
+bool MatrixExprEvaluator::VisitCastExpr(const CastExpr *E) {
+ const auto *MT = E->getType()->castAs<ConstantMatrixType>();
+ unsigned NumRows = MT->getNumRows();
+ unsigned NumCols = MT->getNumColumns();
+ unsigned NElts = NumRows * NumCols;
+ QualType EltTy = MT->getElementType();
+ const Expr *SE = E->getSubExpr();
+
+ switch (E->getCastKind()) {
+ case CK_HLSLAggregateSplatCast: {
+ APValue Val;
+ QualType ValTy;
+
+ if (!hlslAggSplatHelper(Info, SE, Val, ValTy))
+ return false;
+
+ APValue CastedVal;
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ if (!handleScalarCast(Info, FPO, E, ValTy, EltTy, Val, CastedVal))
+ return false;
+
+ SmallVector<APValue, 16> SplatEls(NElts, CastedVal);
+ return Success(SplatEls, NumRows, NumCols, E);
+ }
+ case CK_HLSLElementwiseCast: {
+ SmallVector<APValue> SrcVals;
+ SmallVector<QualType> SrcTypes;
+
+ if (!hlslElementwiseCastHelper(Info, SE, E->getType(), SrcVals, SrcTypes))
+ return false;
+
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ SmallVector<QualType, 16> DestTypes(NElts, EltTy);
+ SmallVector<APValue, 16> ResultEls(NElts);
+ if (!handleElementwiseCast(Info, E, FPO, SrcVals, SrcTypes, DestTypes,
+ ResultEls))
+ return false;
+ // ResultEls contains elements in row-major order, but APValue expects
+ // column-major order. Reorder the elements.
+ SmallVector<APValue, 16> ColMajorEls(NElts);
+ for (unsigned Row = 0; Row < NumRows; Row++)
+ for (unsigned Col = 0; Col < NumCols; Col++)
+ ColMajorEls[Col * NumRows + Row] =
+ std::move(ResultEls[Row * NumCols + Col]);
+ return Success(ColMajorEls, NumRows, NumCols, E);
+ }
+ default:
+ return ExprEvaluatorBaseTy::VisitCastExpr(E);
+ }
+}
+
+bool MatrixExprEvaluator::VisitInitListExpr(const InitListExpr *E) {
+ const auto *MT = E->getType()->castAs<ConstantMatrixType>();
+ unsigned NumRows = MT->getNumRows();
+ unsigned NumCols = MT->getNumColumns();
+ QualType EltTy = MT->getElementType();
+
+ assert(E->getNumInits() == NumRows * NumCols &&
+ "Expected number of elements in initializer list to match the number "
+ "of matrix elements");
+
+ SmallVector<APValue, 16> Elements;
+ Elements.reserve(NumRows * NumCols);
+
+ for (unsigned I = 0; I < E->getNumInits(); ++I) {
+ if (EltTy->isIntegerType()) {
+ llvm::APSInt IntVal;
+ if (!EvaluateInteger(E->getInit(I), IntVal, Info))
+ return false;
+ Elements.push_back(APValue(IntVal));
+ } else {
+ llvm::APFloat FloatVal(0.0);
+ if (!EvaluateFloat(E->getInit(I), FloatVal, Info))
+ return false;
+ Elements.push_back(APValue(FloatVal));
+ }
+ }
+
+ return Success(Elements, NumRows, NumCols, E);
+}
+
//===----------------------------------------------------------------------===//
// Array Evaluation
//===----------------------------------------------------------------------===//
@@ -18942,8 +19101,10 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
return Success(Val.getVectorElt(0), E);
}
case CK_HLSLMatrixTruncation: {
- // TODO: See #168935. Add matrix truncation support to expr constant.
- return Error(E);
+ APValue Val;
+ if (!EvaluateMatrix(SubExpr, Val, Info))
+ return Error(E);
+ return Success(Val.getMatrixElt(0, 0), E);
}
case CK_HLSLElementwiseCast: {
SmallVector<APValue> SrcVals;
@@ -19539,8 +19700,10 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr *E) {
return Success(Val.getVectorElt(0), E);
}
case CK_HLSLMatrixTruncation: {
- // TODO: See #168935. Add matrix truncation support to expr constant.
- return Error(E);
+ APValue Val;
+ if (!EvaluateMatrix(SubExpr, Val, Info))
+ return Error(E);
+ return Success(Val.getMatrixElt(0, 0), E);
}
case CK_HLSLElementwiseCast: {
SmallVector<APValue> SrcVals;
@@ -20439,6 +20602,9 @@ static bool Evaluate(APValue &Result, EvalInfo &Info, const Expr *E) {
} else if (T->isVectorType()) {
if (!EvaluateVector(E, Result, Info))
return false;
+ } else if (T->isConstantMatrixType()) {
+ if (!EvaluateMatrix(E, Result, Info))
+ return false;
} else if (T->isIntegralOrEnumerationType()) {
if (!IntExprEvaluator(Info, Result).Visit(E))
return false;
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 53082bcf78f6a..7211f0eb96523 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -3023,6 +3023,10 @@ bool Type::isLiteralType(const ASTContext &Ctx) const {
if (BaseTy->isScalarType() || BaseTy->isVectorType() ||
BaseTy->isAnyComplexType())
return true;
+ // Matrices with constant numbers of rows and columns are also literal types
+ // in HLSL.
+ if (Ctx.getLangOpts().HLSL && BaseTy->isConstantMatrixType())
+ return true;
// -- a reference type; or
if (BaseTy->isReferenceType())
return true;
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index 0eec4dba4824a..580c95279aaa9 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -2510,6 +2510,36 @@ ConstantEmitter::tryEmitPrivate(const APValue &Value, QualType DestType,
}
return llvm::ConstantVector::get(Inits);
}
+ case APValue::Matrix: {
+ const auto *MT = DestType->castAs<ConstantMatrixType>();
+ unsigned NumRows = Value.getMatrixNumRows();
+ unsigned NumCols = Value.getMatrixNumCols();
+ unsigned NumElts = NumRows * NumCols;
+ SmallVector<llvm::Constant *, 16> Inits(NumElts);
+
+ bool IsRowMajor = CGM.getLangOpts().getDefaultMatrixMemoryLayout() ==
+ LangOptions::MatrixMemoryLayout::MatrixRowMajor;
+
+ for (unsigned Row = 0; Row != NumRows; ++Row) {
+ for (unsigned Col = 0; Col != NumCols; ++Col) {
+ const APValue &Elt = Value.getMatrixElt(Row, Col);
+ // Compute flat index based on memory layout.
+ unsigned Idx = IsRowMajor ? Row * NumCols + Col : Col * NumRows + Row;
+ if (Elt.isInt())
+ Inits[Idx] =
+ llvm::ConstantInt::get(CGM.getLLVMContext(), Elt.getInt());
+ else if (Elt.isFloat())
+ Inits[Idx] =
+ llvm::ConstantFP::get(CGM.getLLVMContext(), Elt.getFloat());
+ else if (Elt.isIndeterminate())
+ Inits[Idx] = llvm::UndefValue::get(
+ CGM.getTypes().ConvertType(MT->getElementType()));
+ else
+ llvm_unreachable("unsupported matrix element type");
+ }
+ }
+ return llvm::ConstantVector::get(Inits);
+ }
case APValue::AddrLabelDiff: {
const AddrLabelExpr *LHSExpr = Value.getAddrLabelDiffLHS();
const AddrLabelExpr *RHSExpr = Value.getAddrLabelDiffRHS();
diff --git a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
new file mode 100644
index 0000000000000..7c4a14beaea8f
--- /dev/null
+++ b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixConstantExpr.hlsl
@@ -0,0 +1,118 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=column-major -verify %s
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -fmatrix-memory-layout=row-major -verify %s
+
+// expected-no-diagnostics
+
+// Matrix subscripting is not currently supported with matrix constexpr. So all
+// tests involve casting to another type to determine if the output is correct.
+
+export void fn() {
+
+ // Matrix truncation to int - should get element at (0,0)
+ {
+ constexpr int2x3 IM = {1, 2, 3,
+ 4, 5, 6};
+ _Static_assert((int)IM == 1, "Woo!");
+ }
+
+ // Scalar splat to matrix, then matrix cast to vector
+ {
+ constexpr bool2x2 BM2x2 = true;
+ constexpr bool4 BV4 = (bool4)BM2x2;
+ _Static_assert(BV4.x == true, "Woo!");
+ _Static_assert(BV4.y == true, "Woo!");
+ _Static_assert(BV4.z == true, "Woo!");
+ _Static_assert(BV4.w == true, "Woo!");
+ }
+
+ // Matrix cast to vector
+ {
+ constexpr float2x2 FM2x2 = {1.5, 2.5, 3.5, 4.5};
+ constexpr float4 FV4 = (float4)FM2x2;
+ _Static_assert(FV4.x == 1.5, "Woo!");
+ _Static_assert(FV4.y == 2.5, "Woo!");
+ _Static_assert(FV4.z == 3.5, "Woo!");
+ _Static_assert(FV4.w == 4.5, "Woo!");
+ }
+
+ // Matrix cast to array
+ {
+ constexpr float2x2 FM2x2 = {1.5, 2.5, 3.5, 4.5};
+ constexpr float FA4[4] = (float[4])FM2x2;
+ _Static_assert(FA4[0] == 1.5, "Woo!");
+ _Static_assert(FA4[1] == 2.5, "Woo!");
+ _Static_assert(FA4[2] == 3.5, "Woo!");
+ _Static_assert(FA4[3] == 4.5, "Woo!");
+ }
+
+ // Array cast to matrix to vector
+ {
+ constexpr int IA4[4] = {1, 2, 3, 4};
+ constexpr int2x2 IM2x2 = (int2x2)IA4;
+ constexpr int4 IV4 = (int4)IM2x2;
+ _Static_assert(IV4.x == 1, "Woo!");
+ _Static_assert(IV4.y == 2, "Woo!");
+ _Static_assert(IV4.z == 3, "Woo!");
+ _Static_assert(IV4.w == 4, "Woo!");
+ }
+
+ // Vector cast to matrix to vector
+ {
+ constexpr bool4 BV4_0 = {true, false, true, false};
+ constexpr bool2x2 BM2x2 = (bool2x2)BV4_0;
+ constexpr bool4 BV4 = (bool4)BM2x2;
+ _Static_assert(BV4.x == true, "Woo!");
+ _Static_assert(BV4.y == false, "Woo!");
+ _Static_assert(BV4.z == true, "Woo!");
+ _Static_assert(BV4.w == false, "Woo!");
+ }
+
+ // Matrix truncation to vector
+ {
+ constexpr int3x2 IM3x2 = { 1, 2,
+ 3, 4,
+ 5, 6};
+ constexpr int4 IV4 = (int4)IM3x2;
+ _Static_assert(IV4.x == 1, "Woo!");
+ _Static_assert(IV4.y == 2, "Woo!");
+ _Static_assert(IV4.z == 3, "Woo!");
+ _Static_assert(IV4.w == 4, "Woo!");
+ }
+
+ // Matrix truncation to array
+ {
+ constexpr int3x2 IM3x2 = { 1, 2,
+ 3, 4,
+ 5, 6};
+ constexpr int IA4[4] = (int[4])IM3x2;
+ _Static_assert(IA4[0] == 1, "Woo!");
+ _Static_assert(IA4[1] == 2, "Woo!");
+ _Static_assert(IA4[2] == 3, "Woo!");
+ _Static_assert(IA4[3] == 4, "Woo!");
+ }
+
+ // Array cast to matrix truncation to vector
+ {
+ constexpr float FA16[16] = { 1.0, 2.0, 3.0, 4.0,
+ 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0,
+ 13.0, 14.0, 15.0, 16.0};
+ constexpr float4x4 FM4x4 = (float4x4)FA16;
+ constexpr float4 FV4 = (float4)FM4x4;
+ _Static_assert(FV4.x == 1.0, "Woo!");
+ _Static_assert(FV4.y == 2.0, "Woo!");
+ _Static_assert(FV4.z == 3.0, "Woo!");
+ _Static_assert(FV4.w == 4.0, "Woo!");
+ }
+
+ // Vector cast to matrix truncation to vector
+ {
+ constexpr bool4 BV4 = {true, false, true, false};
+ constexpr bool2x2 BM2x2 = (bool2x2)BV4;
+ constexpr bool3 BV3 = (bool3)BM2x2;
+ _Static_assert(BV3.x == true, "Woo!");
+ _Static_assert(BV3.y == false, "Woo!");
+ _Static_assert(BV3.z == true, "Woo!");
+ }
+
+}
>From 59f53e8be99b4f453ec2accbf57648cfc404beda Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Thu, 29 Jan 2026 09:19:34 -0800
Subject: [PATCH 2/4] Fix various "enumeration value 'Matrix' not handled in
switch" warnings
Assisted-by: claude-opus-4.5
---
clang/include/clang/AST/PropertiesBase.td | 23 ++++++++++++++++
clang/lib/AST/ASTImporter.cpp | 3 +++
clang/lib/AST/ItaniumMangle.cpp | 6 +++++
clang/lib/AST/MicrosoftMangle.cpp | 5 ++++
clang/lib/AST/TextNodeDumper.cpp | 14 ++++++++++
clang/lib/Sema/SemaTemplate.cpp | 3 +++
.../AST/HLSL/ast-dump-APValue-matrix.hlsl | 26 +++++++++++++++++++
7 files changed, 80 insertions(+)
create mode 100644 clang/test/AST/HLSL/ast-dump-APValue-matrix.hlsl
diff --git a/clang/include/clang/AST/PropertiesBase.td b/clang/include/clang/AST/PropertiesBase.td
index 5b10127526e4e..d3fdcef04fe97 100644
--- a/clang/include/clang/AST/PropertiesBase.td
+++ b/clang/include/clang/AST/PropertiesBase.td
@@ -351,6 +351,29 @@ let Class = PropertyTypeCase<APValue, "Vector"> in {
return result;
}]>;
}
+let Class = PropertyTypeCase<APValue, "Matrix"> in {
+ def : ReadHelper<[{
+ SmallVector<APValue, 16> buffer;
+ unsigned numElts = node.getMatrixNumElements();
+ for (unsigned i = 0; i < numElts; ++i)
+ buffer.push_back(node.getMatrixElt(i));
+ }]>;
+ def : Property<"numRows", UInt32> {
+ let Read = [{ node.getMatrixNumRows() }];
+ }
+ def : Property<"numCols", UInt32> {
+ let Read = [{ node.getMatrixNumCols() }];
+ }
+ def : Property<"elements", Array<APValue>> { let Read = [{ buffer }]; }
+ def : Creator<[{
+ APValue result;
+ result.MakeMatrix();
+ (void)result.setMatrixUninit(numRows, numCols);
+ for (unsigned i = 0; i < elements.size(); i++)
+ result.getMatrixElt(i) = elements[i];
+ return result;
+ }]>;
+}
let Class = PropertyTypeCase<APValue, "Array"> in {
def : ReadHelper<[{
SmallVector<APValue, 4> buffer{};
diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
index 101ab2c40973b..e61a2db8072cf 100644
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -10616,6 +10616,9 @@ ASTNodeImporter::ImportAPValue(const APValue &FromValue) {
Elts.data(), FromValue.getVectorLength());
break;
}
+ case APValue::Matrix:
+ // Matrix values cannot currently arise in APValue import contexts.
+ llvm_unreachable("Matrix APValue import not yet supported");
case APValue::Array:
Result.MakeArray(FromValue.getArrayInitializedElts(),
FromValue.getArraySize());
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index fa28c0d444cc4..0c7b9c7f99c84 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -6463,6 +6463,9 @@ static bool isZeroInitialized(QualType T, const APValue &V) {
return true;
}
+ case APValue::Matrix:
+ llvm_unreachable("Matrix APValues not yet supported");
+
case APValue::Int:
return !V.getInt();
@@ -6677,6 +6680,9 @@ void CXXNameMangler::mangleValueInTemplateArg(QualType T, const APValue &V,
break;
}
+ case APValue::Matrix:
+ llvm_unreachable("Matrix template argument mangling not yet supported");
+
case APValue::Int:
mangleIntegerLiteral(T, V.getInt());
break;
diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp
index 551aa7bf3321c..40146b79b9c37 100644
--- a/clang/lib/AST/MicrosoftMangle.cpp
+++ b/clang/lib/AST/MicrosoftMangle.cpp
@@ -2158,6 +2158,11 @@ void MicrosoftCXXNameMangler::mangleTemplateArgValue(QualType T,
return;
}
+ case APValue::Matrix: {
+ Error("template argument (value type: matrix)");
+ return;
+ }
+
case APValue::AddrLabelDiff: {
Error("template argument (value type: address label diff)");
return;
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index 7bc0404db1bee..c3c8c63652f42 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -620,6 +620,7 @@ static bool isSimpleAPValue(const APValue &Value) {
case APValue::Vector:
case APValue::Array:
case APValue::Struct:
+ case APValue::Matrix:
return false;
case APValue::Union:
return isSimpleAPValue(Value.getUnionValue());
@@ -812,6 +813,19 @@ void TextNodeDumper::Visit(const APValue &Value, QualType Ty) {
return;
}
+ case APValue::Matrix: {
+ unsigned NumRows = Value.getMatrixNumRows();
+ unsigned NumCols = Value.getMatrixNumCols();
+ OS << "Matrix " << NumRows << "x" << NumCols;
+
+ dumpAPValueChildren(
+ Value, Ty,
+ [](const APValue &Value, unsigned Index) -> const APValue & {
+ return Value.getMatrixElt(Index);
+ },
+ Value.getMatrixNumElements(), "element", "elements");
+ return;
+ }
case APValue::Union: {
OS << "Union";
{
diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index 3497ff7856eed..3f28f97d9fed7 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -8127,6 +8127,9 @@ static Expr *BuildExpressionFromNonTypeTemplateArgumentValue(
return MakeInitList(Elts);
}
+ case APValue::Matrix:
+ llvm_unreachable("Matrix template argument expression not yet supported");
+
case APValue::None:
case APValue::Indeterminate:
llvm_unreachable("Unexpected APValue kind.");
diff --git a/clang/test/AST/HLSL/ast-dump-APValue-matrix.hlsl b/clang/test/AST/HLSL/ast-dump-APValue-matrix.hlsl
new file mode 100644
index 0000000000000..27ccb0088e5d9
--- /dev/null
+++ b/clang/test/AST/HLSL/ast-dump-APValue-matrix.hlsl
@@ -0,0 +1,26 @@
+// Test without serialization:
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x \
+// RUN: -ast-dump %s -ast-dump-filter Test \
+// RUN: | FileCheck --strict-whitespace %s
+//
+// Test with serialization:
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -emit-pch -o %t %s
+// RUN: %clang_cc1 -x hlsl -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x \
+// RUN: -include-pch %t -ast-dump-all -ast-dump-filter Test /dev/null \
+// RUN: | sed -e "s/ <undeserialized declarations>//" -e "s/ imported//" \
+// RUN: | FileCheck --strict-whitespace %s
+
+export void Test() {
+ // Matrix with 4 elements (2x2), stored in column-major order
+ constexpr int2x2 mat2x2 = {1, 2, 3, 4};
+ // CHECK: VarDecl {{.*}} mat2x2 {{.*}} constexpr cinit
+ // CHECK-NEXT: |-value: Matrix 2x2
+ // CHECK-NEXT: | `-elements: Int 1, Int 3, Int 2, Int 4
+
+ // Matrix with 6 elements (3x2), stored in column-major order
+ constexpr float3x2 mat3x2 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+ // CHECK: VarDecl {{.*}} mat3x2 {{.*}} constexpr cinit
+ // CHECK-NEXT: |-value: Matrix 3x2
+ // CHECK-NEXT: | |-elements: Float 1.000000e+00, Float 3.000000e+00, Float 5.000000e+00, Float 2.000000e+00
+ // CHECK-NEXT: | `-elements: Float 4.000000e+00, Float 6.000000e+00
+}
>From e3207ad76e5fa9f2b7de9c3ed96d7033afd4be53 Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Thu, 29 Jan 2026 11:20:02 -0800
Subject: [PATCH 3/4] Fix BoolMatrix test due to APValue enabling memcpy
optimization
---
clang/test/CodeGenHLSL/BoolMatrix.hlsl | 16 +++++-----------
1 file changed, 5 insertions(+), 11 deletions(-)
diff --git a/clang/test/CodeGenHLSL/BoolMatrix.hlsl b/clang/test/CodeGenHLSL/BoolMatrix.hlsl
index 824b9656e6848..d61d48d3b74b5 100644
--- a/clang/test/CodeGenHLSL/BoolMatrix.hlsl
+++ b/clang/test/CodeGenHLSL/BoolMatrix.hlsl
@@ -57,12 +57,9 @@ bool2x2 fn2(bool V) {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[RETVAL:%.*]] = alloca i1, align 4
// CHECK-NEXT: [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 1 [[S]], ptr align 1 @__const._Z3fn3v.s, i32 20, i1 false)
// CHECK-NEXT: [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
-// CHECK-NEXT: store <4 x i32> <i32 1, i32 0, i32 1, i32 0>, ptr [[BM]], align 1
-// CHECK-NEXT: [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 1
-// CHECK-NEXT: store float 1.000000e+00, ptr [[F]], align 1
-// CHECK-NEXT: [[BM1:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
-// CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr [[BM1]], align 1
+// CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr [[BM]], align 1
// CHECK-NEXT: [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 0
// CHECK-NEXT: store i32 [[MATRIXEXT]], ptr [[RETVAL]], align 4
// CHECK-NEXT: [[TMP1:%.*]] = load i1, ptr [[RETVAL]], align 4
@@ -113,15 +110,12 @@ void fn5() {
// CHECK-NEXT: [[V:%.*]] = alloca i32, align 4
// CHECK-NEXT: [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1
// CHECK-NEXT: store i32 0, ptr [[V]], align 4
-// CHECK-NEXT: [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
-// CHECK-NEXT: store <4 x i32> <i32 1, i32 0, i32 1, i32 0>, ptr [[BM]], align 1
-// CHECK-NEXT: [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 1
-// CHECK-NEXT: store float 1.000000e+00, ptr [[F]], align 1
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 1 [[S]], ptr align 1 @__const._Z3fn6v.s, i32 20, i1 false)
// CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr [[V]], align 4
// CHECK-NEXT: [[LOADEDV:%.*]] = trunc i32 [[TMP0]] to i1
-// CHECK-NEXT: [[BM1:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
+// CHECK-NEXT: [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
// CHECK-NEXT: [[TMP1:%.*]] = zext i1 [[LOADEDV]] to i32
-// CHECK-NEXT: [[TMP2:%.*]] = getelementptr <4 x i32>, ptr [[BM1]], i32 0, i32 1
+// CHECK-NEXT: [[TMP2:%.*]] = getelementptr <4 x i32>, ptr [[BM]], i32 0, i32 1
// CHECK-NEXT: store i32 [[TMP1]], ptr [[TMP2]], align 4
// CHECK-NEXT: ret void
//
>From 4fe41dfc5b48d5cefcfb4b12e8030e36202b1c6c Mon Sep 17 00:00:00 2001
From: Deric Cheung <cheung.deric at gmail.com>
Date: Thu, 29 Jan 2026 14:50:54 -0800
Subject: [PATCH 4/4] Replace undef with poison
---
clang/lib/CodeGen/CGExprConstant.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index 580c95279aaa9..e42dc6d9e44fe 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -2532,7 +2532,7 @@ ConstantEmitter::tryEmitPrivate(const APValue &Value, QualType DestType,
Inits[Idx] =
llvm::ConstantFP::get(CGM.getLLVMContext(), Elt.getFloat());
else if (Elt.isIndeterminate())
- Inits[Idx] = llvm::UndefValue::get(
+ Inits[Idx] = llvm::PoisonValue::get(
CGM.getTypes().ConvertType(MT->getElementType()));
else
llvm_unreachable("unsupported matrix element type");
More information about the cfe-commits
mailing list