[clang] [HLSL] add support for HLSLAggregateSplatCast and HLSLElementwiseCast to constant expression evaluator (PR #164700)
Sarah Spall via cfe-commits
cfe-commits at lists.llvm.org
Tue Nov 4 13:11:06 PST 2025
https://github.com/spall updated https://github.com/llvm/llvm-project/pull/164700
>From 276fca41ad8e81ce4189266c20d260646d6d5f4c Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Fri, 31 Jan 2025 16:57:24 -0800
Subject: [PATCH 1/6] add support for HLSLAggregateSplatCast and
HLSLElementwiseCast to Constant expression evaluator. Add tests. Fix/Add
support for other minor necessary things.
---
clang/lib/AST/ExprConstant.cpp | 587 +++++++++++++++++-
.../Types/AggregateSplatConstantExpr.hlsl | 89 +++
.../BuiltinVector/TruncationConstantExpr.hlsl | 21 +
.../Types/ElementwiseCastConstantExpr.hlsl | 76 +++
4 files changed, 772 insertions(+), 1 deletion(-)
create mode 100644 clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
create mode 100644 clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 00aaaab957591..5dfb2b3e3491f 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -3828,6 +3828,333 @@ static bool CheckArraySize(EvalInfo &Info, const ConstantArrayType *CAT,
/*Diag=*/true);
}
+// Convert one scalar APValue (Original, of type SourceTy) to DestTy and store
+// the converted value in Result. Bool, integer and real-floating source and
+// destination types are handled. Returns false if a conversion helper fails,
+// or if the (SourceTy, DestTy) pair is not one of the handled combinations; in
+// that last case no diagnostic is emitted (the FFDiag below is commented out).
+// NOTE(review): in this revision the bool -> floating path builds its
+// intermediate APSInt with MakeIntValue(..., DestTy), i.e. with a
+// floating-point type (see the TODO below); it should use an integer type.
+static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
+ QualType SourceTy, QualType DestTy,
+ APValue const &Original, APValue &Result) {
+ // boolean must be checked before integer
+ // since IsIntegerType() is true for bool
+ if (SourceTy->isBooleanType()) {
+ if (DestTy->isBooleanType()) {
+ Result = Original;
+ return true;
+ }
+ if (DestTy->isIntegerType() || DestTy->isRealFloatingType()) {
+ bool BoolResult;
+ if (!HandleConversionToBool(Original, BoolResult))
+ return false;
+ uint64_t IntResult = BoolResult;
+ Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+ // TODO destty is wrong here if destty is float....
+ // can we use sourcety here?
+ }
+ if (DestTy->isFloatingType()) {
+ APValue Result2 = APValue(APFloat(0.0));
+ if (!HandleIntToFloatCast(Info, E, FPO,
+ Info.Ctx.getIntTypeForBitwidth(64, true),
+ Result.getInt(), DestTy, Result2.getFloat()))
+ return false;
+ Result = Result2;
+ }
+ return true;
+ }
+ if (SourceTy->isIntegerType()) {
+ if (DestTy->isRealFloatingType()) {
+ Result = APValue(APFloat(0.0));
+ return HandleIntToFloatCast(Info, E, FPO, SourceTy, Original.getInt(),
+ DestTy, Result.getFloat());
+ }
+ if (DestTy->isBooleanType()) {
+ bool BoolResult;
+ if (!HandleConversionToBool(Original, BoolResult))
+ return false;
+ uint64_t IntResult = BoolResult;
+ Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+ return true;
+ }
+ if (DestTy->isIntegerType()) {
+ Result = APValue(
+ HandleIntToIntCast(Info, E, DestTy, SourceTy, Original.getInt()));
+ return true;
+ }
+ } else if (SourceTy->isRealFloatingType()) {
+ if (DestTy->isRealFloatingType()) {
+ Result = Original;
+ return HandleFloatToFloatCast(Info, E, SourceTy, DestTy,
+ Result.getFloat());
+ }
+ if (DestTy->isBooleanType()) {
+ bool BoolResult;
+ if (!HandleConversionToBool(Original, BoolResult))
+ return false;
+ uint64_t IntResult = BoolResult;
+ Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+ return true;
+ }
+ if (DestTy->isIntegerType()) {
+ Result = APValue(APSInt());
+ return HandleFloatToIntCast(Info, E, SourceTy, Original.getFloat(),
+ DestTy, Result.getInt());
+ }
+ }
+
+ // Info.FFDiag(E, diag::err_convertvector_constexpr_unsupported_vector_cast)
+ // << SourceTy << DestTy;
+ return false;
+}
+
+// Build Result, a value of type ResultType, from the flat list of scalar
+// Elements (whose source types are given by ElTypes). Each element is
+// converted with handleScalarCast to the type of the scalar slot it fills.
+// Slots are visited in order with an explicit worklist: array elements by
+// index, record base classes before the named fields, unnamed bit-fields
+// skipped. Integer bit-field slots are truncated to the declared width.
+// Filling stops when either the slots or the Elements run out. Returns false
+// on a failed conversion or on a slot type that is not handled below.
+static bool constructAggregate(EvalInfo &Info, const FPOptions FPO,
+ const Expr *E, APValue &Result,
+ QualType ResultType,
+ SmallVectorImpl<APValue> &Elements,
+ SmallVectorImpl<QualType> &ElTypes) {
+
+ SmallVector<std::tuple<APValue *, QualType, unsigned>> WorkList = {
+ {&Result, ResultType, 0}};
+
+ unsigned ElI = 0;
+ while (!WorkList.empty() && ElI < Elements.size()) {
+ auto [Res, Type, BitWidth] = WorkList.pop_back_val();
+
+ if (Type->isRealFloatingType() || Type->isBooleanType()) {
+ if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+ *Res))
+ return false;
+ ElI++;
+ continue;
+ }
+ if (Type->isIntegerType()) {
+ if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+ *Res))
+ return false;
+ // A non-zero BitWidth marks a bit-field slot (set from the field's
+ // declared width below): keep only the low BitWidth bits, then extend
+ // back to the APSInt's original width.
+ if (BitWidth > 0) {
+ if (!Res->isInt())
+ return false;
+ APSInt &Int = Res->getInt();
+ unsigned OldBitWidth = Int.getBitWidth();
+ unsigned NewBitWidth = BitWidth;
+ if (NewBitWidth < OldBitWidth)
+ Int = Int.trunc(NewBitWidth).extend(OldBitWidth);
+ }
+ ElI++;
+ continue;
+ }
+ // Each vector lane consumes one element.
+ // NOTE(review): ElI is not checked against Elements.size() inside this
+ // loop, so fewer remaining elements than lanes would index past the end of
+ // Elements/ElTypes; confirm Sema guarantees enough source elements.
+ if (Type->isVectorType()) {
+ QualType ElTy = Type->castAs<VectorType>()->getElementType();
+ unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+ SmallVector<APValue> Vals(NumEl);
+ for (unsigned I = 0; I < NumEl; ++I) {
+ if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], ElTy, Elements[ElI],
+ Vals[I]))
+ return false;
+ ElI++;
+ }
+ *Res = APValue(Vals.data(), NumEl);
+ continue;
+ }
+ if (Type->isConstantArrayType()) {
+ QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+ ->getElementType();
+ uint64_t Size =
+ cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+ *Res = APValue(APValue::UninitArray(), Size, Size);
+ for (int64_t I = Size - 1; I > -1; --I) {
+ WorkList.emplace_back(&Res->getArrayInitializedElt(I), ElTy, 0u);
+ }
+ continue;
+ }
+ if (Type->isRecordType()) {
+ const RecordDecl *RD = Type->getAsRecordDecl();
+
+ unsigned NumBases = 0;
+ if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD))
+ NumBases = CXXRD->getNumBases();
+
+ *Res = APValue(APValue::UninitStruct(), NumBases,
+ std::distance(RD->field_begin(), RD->field_end()));
+
+ SmallVector<std::tuple<APValue *, QualType, unsigned>> ReverseList;
+ // we need to traverse backwards
+ // Visit the base classes.
+ if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+ // todo assert there is only 1 base at most
+ for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+ const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+ ReverseList.emplace_back(&Res->getStructBase(I), BS.getType(), 0u);
+ }
+ }
+
+ // Visit the fields.
+ for (FieldDecl *FD : RD->fields()) {
+ unsigned FDBW = 0;
+ if (FD->isUnnamedBitField())
+ continue;
+ if (FD->isBitField()) {
+ FDBW = FD->getBitWidthValue();
+ }
+
+ ReverseList.emplace_back(&Res->getStructField(FD->getFieldIndex()),
+ FD->getType(), FDBW);
+ }
+
+ std::reverse(ReverseList.begin(), ReverseList.end());
+ llvm::append_range(WorkList, ReverseList);
+ continue;
+ }
+ return false;
+ }
+ return true;
+}
+
+// Convert each Elements[I] from SrcTypes[I] to DestTypes[I] with
+// handleScalarCast, writing the converted value to Results[I]. Returns false
+// at the first conversion that fails.
+// NOTE(review): Results is indexed rather than appended to, so the caller must
+// size it to at least Elements.size() beforehand (the callers in this patch
+// construct it with an explicit element count).
+static bool handleElementwiseCast(EvalInfo &Info, const Expr *E,
+ const FPOptions FPO,
+ SmallVectorImpl<APValue> &Elements,
+ SmallVectorImpl<QualType> &SrcTypes,
+ SmallVectorImpl<QualType> &DestTypes,
+ SmallVectorImpl<APValue> &Results) {
+
+ assert((Elements.size() == SrcTypes.size()) &&
+ (Elements.size() == DestTypes.size()));
+
+ for (unsigned I = 0, ESz = Elements.size(); I < ESz; ++I) {
+ APValue Original = Elements[I];
+ QualType SourceTy = SrcTypes[I];
+ QualType DestTy = DestTypes[I];
+
+ if (!handleScalarCast(Info, FPO, E, SourceTy, DestTy, Original, Results[I]))
+ return false;
+ }
+ return true;
+}
+
+// Count the scalar slots of BaseTy, i.e. how many elements a flattened value
+// of this type has. Each integer, bool or real-floating scalar counts 1, a
+// vector counts its number of lanes, and a constant array counts its length
+// times its element type's count. A record counts its base classes plus its
+// fields, skipping unnamed bit-fields. Any other type contributes nothing.
+static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) {
+
+ SmallVector<QualType> WorkList = {BaseTy};
+
+ unsigned Size = 0;
+ while (!WorkList.empty()) {
+ QualType Type = WorkList.pop_back_val();
+ if (Type->isRealFloatingType() || Type->isIntegerType() ||
+ Type->isBooleanType()) {
+ ++Size;
+ continue;
+ }
+ if (Type->isVectorType()) {
+ unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+ Size += NumEl;
+ continue;
+ }
+ if (Type->isConstantArrayType()) {
+ QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+ ->getElementType();
+ // NOTE(review): this local `Size` shadows the running count declared
+ // above; here it is only the array length used as the loop bound. A
+ // distinct name (e.g. NumEls) would avoid confusing the two.
+ uint64_t Size =
+ cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+ for (uint64_t I = 0; I < Size; ++I) {
+ WorkList.push_back(ElTy);
+ }
+ continue;
+ }
+ if (Type->isRecordType()) {
+ const RecordDecl *RD = Type->getAsRecordDecl();
+ // const ASTRecordLayout &Layout = Info.Ctx.getASTRecordLayout(RD);
+
+ // Visit the base classes.
+ if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+ // todo assert there is only 1 base at most
+ for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+ const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+ WorkList.push_back(BS.getType());
+ }
+ }
+
+ // visit the fields.
+ for (FieldDecl *FD : RD->fields()) {
+ if (FD->isUnnamedBitField())
+ continue;
+ WorkList.push_back(FD->getType());
+ }
+ continue;
+ }
+ }
+ return Size;
+}
+
+// Flatten Value (of type BaseTy) into scalars. Each scalar APValue is appended
+// to Elements and its type to Types, stopping once Size scalars have been
+// produced. Vector lanes and array elements are emitted in index order. For a
+// record, its single base class (asserted) is emitted before its fields, and
+// unnamed bit-fields are skipped. Returns false for an APValue kind not
+// handled below, or when a record's base subobject is not a struct.
+// NOTE(review): Value and every worklist entry are APValue copies, which may
+// be costly for large aggregates.
+static bool flattenAPValue(const ASTContext &Ctx, APValue Value,
+ QualType BaseTy, SmallVectorImpl<APValue> &Elements,
+ SmallVectorImpl<QualType> &Types, unsigned Size) {
+
+ SmallVector<std::pair<APValue, QualType>> WorkList = {{Value, BaseTy}};
+ unsigned Populated = 0;
+ while (!WorkList.empty() && Populated < Size) {
+ auto [Work, Type] = WorkList.pop_back_val();
+
+ if (Work.isFloat() || Work.isInt()) { // todo what does this do with bool
+ Elements.push_back(Work);
+ Types.push_back(Type);
+ Populated++;
+ continue;
+ }
+ if (Work.isVector()) {
+ assert(Type->isVectorType() && "Type mismatch.");
+ QualType ElTy = Type->castAs<VectorType>()->getElementType();
+ for (unsigned I = 0; I < Work.getVectorLength() && Populated < Size;
+ I++) {
+ Elements.push_back(Work.getVectorElt(I));
+ Types.push_back(ElTy);
+ Populated++;
+ }
+ continue;
+ }
+ if (Work.isArray()) {
+ assert(Type->isConstantArrayType() && "Type mismatch.");
+ QualType ElTy =
+ cast<ConstantArrayType>(Ctx.getAsArrayType(Type))->getElementType();
+ for (int64_t I = Work.getArraySize() - 1; I > -1; --I) {
+ WorkList.emplace_back(Work.getArrayInitializedElt(I), ElTy);
+ }
+ continue;
+ }
+
+ if (Work.isStruct()) {
+ assert(Type->isRecordType() && "Type mismatch.");
+
+ const RecordDecl *RD = Type->getAsRecordDecl();
+
+ SmallVector<std::pair<APValue, QualType>> ReverseList;
+ // Visit the fields.
+ for (FieldDecl *FD : RD->fields()) {
+ if (FD->isUnnamedBitField())
+ continue;
+ // if (FD->isBitField()) {
+ ReverseList.emplace_back(Work.getStructField(FD->getFieldIndex()),
+ FD->getType());
+ }
+
+ std::reverse(ReverseList.begin(), ReverseList.end());
+ llvm::append_range(WorkList, ReverseList);
+
+ // Visit the base classes.
+ if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+ if (CXXRD->getNumBases() > 0) {
+ assert(CXXRD->getNumBases() == 1);
+ const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
+ const APValue &Base = Work.getStructBase(0);
+
+ // Can happen in error cases.
+ if (!Base.isStruct())
+ return false;
+
+ WorkList.emplace_back(Base, BS.getType());
+ }
+ }
+ continue;
+ }
+ return false;
+ }
+ return true;
+}
+
namespace {
/// A handle to a complete object (an object that is not a subobject of
/// another object).
@@ -8666,6 +8993,25 @@ class ExprEvaluatorBase
case CK_UserDefinedConversion:
return StmtVisitorTy::Visit(E->getSubExpr());
+ case CK_HLSLArrayRValue: {
+ const Expr *SubExpr = E->getSubExpr();
+ if (!SubExpr->isGLValue()) {
+ APValue Val;
+ if (!Evaluate(Val, Info, SubExpr))
+ return false;
+ return DerivedSuccess(Val, E);
+ }
+
+ LValue LVal;
+ if (!EvaluateLValue(SubExpr, LVal, Info))
+ return false;
+ APValue RVal;
+ // Note, we use the subexpression's type in order to retain cv-qualifiers.
+ if (!handleLValueToRValueConversion(Info, E, SubExpr->getType(), LVal,
+ RVal))
+ return false;
+ return DerivedSuccess(RVal, E);
+ }
case CK_LValueToRValue: {
LValue LVal;
if (!EvaluateLValue(E->getSubExpr(), LVal, Info))
@@ -10850,6 +11196,67 @@ bool RecordExprEvaluator::VisitCastExpr(const CastExpr *E) {
Result = *Value;
return true;
}
+ case CK_HLSLAggregateSplatCast: {
+ APValue Val;
+ const Expr *SE = E->getSubExpr();
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ unsigned NEls = elementwiseSize(Info, E->getType());
+ // flatten the source
+ SmallVector<APValue, 1> SrcEls;
+ SmallVector<QualType, 1> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+ return Error(E);
+
+ // check there is only one and splat it
+ assert(SrcEls.size() == 1);
+ SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+ SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+ APValue Tmp;
+ handleDefaultInitValue(E->getType(), Tmp);
+
+ // cast the elements and construct our struct result
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+ SplatType))
+ return Error(E);
+
+ return true;
+ }
+ case CK_HLSLElementwiseCast: {
+ APValue Val;
+ const Expr *SE = E->getSubExpr();
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ // must be dealing with a record;
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+ return false;
+ }
+
+ // flatten the source
+ SmallVector<APValue> SrcEls;
+ SmallVector<QualType> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+ UINT_MAX))
+ return Error(E);
+
+ // cast the elements and construct our struct result
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+
+ if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+ SrcTypes))
+ return Error(E);
+
+ return true;
+ }
}
}
@@ -11345,6 +11752,58 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) {
Elements.push_back(Val.getVectorElt(I));
return Success(Elements, E);
}
+ case CK_HLSLAggregateSplatCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ // this cast doesn't handle splatting from scalars when result is a vector
+ SmallVector<APValue, 1> Elements;
+ SmallVector<QualType, 1> DestTypes = {VTy->getElementType()};
+ SmallVector<QualType, 1> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+ return Error(E);
+
+ // check there is only one element and cast and splat it
+ assert(Elements.size() == 1 &&
+ "HLSLAggregateSplatCast RHS must contain one element");
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ SmallVector<APValue, 1> ResultEls(1);
+ if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+ ResultEls))
+ return Error(E);
+
+ SmallVector<APValue, 4> SplatEls(NElts, ResultEls[0]);
+ return Success(SplatEls, E);
+ }
+ case CK_HLSLElementwiseCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ // must be dealing with a record;
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+ return false;
+ }
+
+ SmallVector<APValue, 4> Elements;
+ SmallVector<QualType, 4> DestTypes(NElts, VTy->getElementType());
+ SmallVector<QualType, 4> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+ return Error(E);
+ // cast elements
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ SmallVector<APValue, 4> ResultEls(NElts);
+ if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+ ResultEls))
+ return Error(E);
+ return Success(ResultEls, E);
+ }
default:
return ExprEvaluatorBaseTy::VisitCastExpr(E);
}
@@ -13029,6 +13488,7 @@ namespace {
bool VisitCallExpr(const CallExpr *E) {
return handleCallExpr(E, Result, &This);
}
+ bool VisitCastExpr(const CastExpr *E);
bool VisitInitListExpr(const InitListExpr *E,
QualType AllocType = QualType());
bool VisitArrayInitLoopExpr(const ArrayInitLoopExpr *E);
@@ -13099,6 +13559,70 @@ static bool MaybeElementDependentArrayFiller(const Expr *FillerExpr) {
return true;
}
+// Evaluate a cast whose result is an array.
+// - HLSL aggregate-splat casts: flatten the source, which is asserted to hold
+//   exactly one scalar, replicate that scalar once per scalar slot of the
+//   destination type, and build the array with constructAggregate.
+// - HLSL elementwise casts: load the source first if it evaluated to an
+//   lvalue, flatten the whole source, and build the array from those
+//   elements.
+// - Every other cast kind is forwarded to the base evaluator.
+bool ArrayExprEvaluator::VisitCastExpr(const CastExpr *E) {
+ const Expr *SE = E->getSubExpr();
+
+ switch (E->getCastKind()) {
+ default:
+ return ExprEvaluatorBaseTy::VisitCastExpr(E);
+ case CK_HLSLAggregateSplatCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ unsigned NEls = elementwiseSize(Info, E->getType());
+ // flatten the source
+ SmallVector<APValue, 1> SrcEls;
+ SmallVector<QualType, 1> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+ return Error(E);
+
+ // check there is only one and splat it
+ // NOTE(review): if NEls is 0 (a destination with no scalar slots),
+ // flattenAPValue produces nothing and SrcEls[0] below is out of range; the
+ // assert only catches this in assertion-enabled builds. Confirm Sema never
+ // forms such a cast.
+ assert(SrcEls.size() == 1);
+ SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+ SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+ // cast the elements
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+ SplatType))
+ return Error(E);
+
+ return true;
+ }
+ case CK_HLSLElementwiseCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ // If the source evaluated to an lvalue, load the value it refers to.
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ // NOTE(review): unlike the other failure paths here, this returns false
+ // rather than Error(E).
+ if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+ return false;
+ }
+
+ // flatten the source
+ SmallVector<APValue> SrcEls;
+ SmallVector<QualType> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+ UINT_MAX))
+ return Error(E);
+
+ // cast the elements
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+ SrcTypes))
+ return Error(E);
+
+ return true;
+ }
+ }
+}
+
bool ArrayExprEvaluator::VisitInitListExpr(const InitListExpr *E,
QualType AllocType) {
const ConstantArrayType *CAT = Info.Ctx.getAsConstantArrayType(
@@ -16801,7 +17325,6 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_NoOp:
case CK_LValueToRValueBitCast:
case CK_HLSLArrayRValue:
- case CK_HLSLElementwiseCast:
return ExprEvaluatorBaseTy::VisitCastExpr(E);
case CK_MemberPointerToBoolean:
@@ -16948,6 +17471,35 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
return Error(E);
return Success(Val.getVectorElt(0), E);
}
+ case CK_HLSLElementwiseCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SubExpr))
+ return Error(E);
+
+ // must be dealing with a record;
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(),
+ LVal, Val))
+ return false;
+ }
+
+ SmallVector<APValue, 1> Elements;
+ SmallVector<QualType, 1> DestTypes(1, DestType);
+ SmallVector<QualType, 1> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SrcType, Elements, SrcTypes, 1))
+ return Error(E);
+
+ // cast our single element
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ APValue ResultVal;
+ if (!handleScalarCast(Info, FPO, E, SrcTypes[0], DestTypes[0], Elements[0],
+ ResultVal))
+ return Error(E);
+ return Success(ResultVal, E);
+ }
}
llvm_unreachable("unknown cast resulting in integral value");
@@ -17485,6 +18037,9 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr *E) {
default:
return ExprEvaluatorBaseTy::VisitCastExpr(E);
+ case CK_HLSLAggregateSplatCast:
+ llvm_unreachable("invalid cast kind for floating value");
+
case CK_IntegralToFloating: {
APSInt IntResult;
const FPOptions FPO = E->getFPFeaturesInEffect(
@@ -17523,6 +18078,36 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr *E) {
return Error(E);
return Success(Val.getVectorElt(0), E);
}
+ case CK_HLSLElementwiseCast: {
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SubExpr))
+ return Error(E);
+
+ // must be dealing with a record;
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(),
+ LVal, Val))
+ return false;
+ }
+
+ SmallVector<APValue, 1> Elements;
+ SmallVector<QualType, 1> DestTypes(1, E->getType());
+ SmallVector<QualType, 1> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SubExpr->getType(), Elements, SrcTypes,
+ 1))
+ return Error(E);
+
+ // cast our single element
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ APValue ResultVal;
+ if (!handleScalarCast(Info, FPO, E, SrcTypes[0], DestTypes[0], Elements[0],
+ ResultVal))
+ return Error(E);
+ return Success(ResultVal, E);
+ }
}
}
diff --git a/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl b/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
new file mode 100644
index 0000000000000..7df41f24ee0d9
--- /dev/null
+++ b/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
@@ -0,0 +1,89 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -fnative-half-type -std=hlsl202x -verify %s
+
+// expected-no-diagnostics
+
+// Test aggregates. Base mixes a scalar, a vector and two 5-bit bit-fields (one
+// signed, one unsigned) so the splat casts exercise bit-field truncation.
+struct Base {
+ double D;
+ uint64_t2 U;
+ int16_t I : 5;
+ uint16_t I2: 5;
+};
+
+// R adds a 10-bit bit-field, an unnamed bit-field (which has no scalar slot
+// of its own) and a float after the inherited Base members.
+struct R : Base {
+ int G : 10;
+ int : 30;
+ float F;
+};
+
+struct B1 {
+ float A;
+ float B;
+};
+
+// B2's scalar slots are B1's {A, B} followed by {C, D, BB}.
+struct B2 : B1 {
+ int C;
+ int D;
+ bool BB;
+};
+
+// tests for HLSLAggregateSplatCast: a scalar (or a one-element vector) is
+// converted and copied into every scalar slot of the destination type.
+export void fn() {
+ // result type vector
+ // splat from a vector of size 1
+
+ constexpr float1 Y = {1.0};
+ constexpr int1 A1 = {1};
+ constexpr float4 F4 = (float4)Y;
+ _Static_assert(F4[0] == 1.0, "Woo!");
+ _Static_assert(F4[1] == 1.0, "Woo!");
+ _Static_assert(F4[2] == 1.0, "Woo!");
+ _Static_assert(F4[3] == 1.0, "Woo!");
+
+ // result type array
+ // splat from a scalar
+ // (3.33 converts to 3 for each int element)
+ constexpr float F = 3.33;
+ constexpr int B6[6] = (int[6])F;
+ _Static_assert(B6[0] == 3, "Woo!");
+ _Static_assert(B6[1] == 3, "Woo!");
+ _Static_assert(B6[2] == 3, "Woo!");
+ _Static_assert(B6[3] == 3, "Woo!");
+ _Static_assert(B6[4] == 3, "Woo!");
+ _Static_assert(B6[5] == 3, "Woo!");
+
+ // splat from a vector of size 1
+ constexpr uint64_t2 A7[2] = (uint64_t2[2])A1;
+ _Static_assert(A7[0][0] == 1, "Woo!");
+ _Static_assert(A7[0][1] == 1, "Woo!");
+ _Static_assert(A7[1][0] == 1, "Woo!");
+ _Static_assert(A7[1][1] == 1, "Woo!");
+
+ // result type struct
+ // splat from a scalar
+ // Integer slots receive 100. The 5-bit bit-fields I and I2 keep only the
+ // low 5 bits of 100 (0b1100100), i.e. 0b00100 == 4; the 10-bit G holds 100.
+ constexpr double D = 100.6789;
+ constexpr R SR = (R)D;
+ _Static_assert(SR.D == 100.6789, "Woo!");
+ _Static_assert(SR.U[0] == 100, "Woo!");
+ _Static_assert(SR.U[1] == 100, "Woo!");
+ _Static_assert(SR.I == 4, "Woo!");
+ _Static_assert(SR.I2 == 4, "Woo!");
+ _Static_assert(SR.G == 100, "Woo!");
+ _Static_assert(SR.F == 100.6789, "Woo!");
+
+ // splat from a vector of size 1
+ constexpr float1 A100 = {1000.1111};
+ constexpr B2 SB2 = (B2)A100;
+ _Static_assert(SB2.A == 1000.1111, "Woo!");
+ _Static_assert(SB2.B == 1000.1111, "Woo!");
+ _Static_assert(SB2.C == 1000, "Woo!");
+ _Static_assert(SB2.D == 1000, "Woo!");
+ _Static_assert(SB2.BB == true, "Woo!");
+
+ // splat from a bool to an int and float etc
+ constexpr bool B = true;
+ constexpr B2 SB3 = (B2)B;
+ _Static_assert(SB3.A == 1.0, "Woo!");
+ _Static_assert(SB3.B == 1.0, "Woo!");
+ _Static_assert(SB3.C == 1, "Woo!");
+ _Static_assert(SB3.D == 1, "Woo!");
+ _Static_assert(SB3.BB == true, "Woo!");
+}
diff --git a/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
index 918daa03d8032..85b7271fd2c8c 100644
--- a/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
@@ -5,7 +5,28 @@
// Note: these tests are a bit awkward because at time of writing we don't have a
// good way to constexpr `any` for bool vector conditions, and the condition for
// _Static_assert must be an integral constant.
+
+struct S {
+ int3 A;
+ float B;
+};
+
export void fn() {
+
+ _Static_assert(((float4)(int[6]){1,2,3,4,5,6}).x == 1.0, "Woo!");
+
+ // This compiling successfully verifies that the array constant expression
+ // gets truncated to a float at compile time for instantiation via the
+ // flat cast
+ _Static_assert(((int)(int[2]){1,2}) == 1, "Woo!");
+
+ // This compiling successfully verifies that the struct constant expression
+ // gets truncated to an integer at compile time for instatiation via the
+ // flat cast
+ _Static_assert(((int)(S){{1,2,3},1.0}) == 1, "Woo!");
+
+ _Static_assert(((float)(float[2]){1.0,2.0}) == 1.0, "Woo!");
+
// This compiling successfully verifies that the vector constant expression
// gets truncated to an integer at compile time for instantiation.
_Static_assert(((int)1.xxxx) + 0 == 1, "Woo!");
diff --git a/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
new file mode 100644
index 0000000000000..6f697e8097a21
--- /dev/null
+++ b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
@@ -0,0 +1,76 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -fnative-half-type -std=hlsl202x -verify %s
+
+// expected-no-diagnostics
+
+// Test aggregates. When flattened, Base has five scalar slots
+// {D, U[0], U[1], I, I2}, and B2 has five slots {A, B, C, D, BB} (base-class
+// members first).
+struct Base {
+ double D;
+ uint64_t2 U;
+ int16_t I : 5;
+ uint16_t I2: 5;
+};
+
+struct R : Base {
+ int G : 10;
+ int : 30;
+ float F;
+};
+
+struct B1 {
+ float A;
+ float B;
+};
+
+struct B2 : B1 {
+ int C;
+ int D;
+ bool BB;
+};
+
+export void fn() {
+
+ // truncation tests
+ // result type int
+ // truncate from struct
+ constexpr B1 SB1 = {1.0, 3.0};
+ // NOTE(review): Blah is not used by any assertion below.
+ constexpr float Blah = SB1.A;
+ constexpr int X = (int)SB1;
+ _Static_assert(X == 1, "Woo!");
+
+ // result type float
+ // truncate from array
+ // (Arr flattens to {4.0, 3.0, 2.0, 1.0}; only the first element is used)
+ constexpr B1 Arr[2] = {4.0, 3.0, 2.0, 1.0};
+ constexpr float F = (float)Arr;
+ _Static_assert(F == 4.0, "Woo!");
+
+ // result type vector
+ // truncate from array of vector
+ constexpr int2 Arr2[2] = {5,6,7,8};
+ constexpr int2 I2 = (int2)Arr2;
+ _Static_assert(I2[0] == 5, "Woo!");
+ _Static_assert(I2[1] == 6, "Woo!");
+
+ // lhs and rhs are same "size" tests
+
+ // result type vector from array
+ constexpr int4 I4 = (int4)Arr;
+ _Static_assert(I4[0] == 4, "Woo!");
+ _Static_assert(I4[1] == 3, "Woo!");
+ _Static_assert(I4[2] == 2, "Woo!");
+ _Static_assert(I4[3] == 1, "Woo!");
+
+ // result type array from vector
+ constexpr double3 D3 = {100.11, 200.11, 300.11};
+ constexpr float FArr[3] = (float[3])D3;
+ _Static_assert(FArr[0] == 100.11, "Woo!");
+ _Static_assert(FArr[1] == 200.11, "Woo!");
+ _Static_assert(FArr[2] == 300.11, "Woo!");
+
+ // result type struct from struct
+ // SB2 flattens to {5.5, 6.5, 1000, 5000, false}, which Base takes in order:
+ // D = 5.5, U = {6, 1000}, I = low 5 bits of 5000 (0b...01000) == 8,
+ // I2 = false -> 0.
+ constexpr B2 SB2 = {5.5, 6.5, 1000, 5000, false};
+ constexpr Base SB = (Base)SB2;
+ _Static_assert(SB.D == 5.5, "Woo!");
+ _Static_assert(SB.U[0] == 6, "Woo!");
+ _Static_assert(SB.U[1] == 1000, "Woo!");
+ _Static_assert(SB.I == 8, "Woo!");
+ _Static_assert(SB.I2 == 0, "Woo!");
+}
>From 77abb6bceebe5d98534b1298879e52a3dbd02bea Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Wed, 22 Oct 2025 14:51:12 -0700
Subject: [PATCH 2/6] first self review
---
clang/lib/AST/ExprConstant.cpp | 25 ++++++++-----------
.../BuiltinVector/TruncationConstantExpr.hlsl | 19 --------------
.../Types/ElementwiseCastConstantExpr.hlsl | 16 ++++++++++++
3 files changed, 26 insertions(+), 34 deletions(-)
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 5dfb2b3e3491f..f9a51bfc94751 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -3843,9 +3843,8 @@ static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
if (!HandleConversionToBool(Original, BoolResult))
return false;
uint64_t IntResult = BoolResult;
- Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
- // TODO destty is wrong here if destty is float....
- // can we use sourcety here?
+ Result = APValue(Info.Ctx.MakeIntValue(
+ IntResult, Info.Ctx.getIntTypeForBitwidth(64, true)));
}
if (DestTy->isFloatingType()) {
APValue Result2 = APValue(APFloat(0.0));
@@ -3897,8 +3896,6 @@ static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
}
}
- // Info.FFDiag(E, diag::err_convertvector_constexpr_unsupported_vector_cast)
- // << SourceTy << DestTy;
return false;
}
@@ -3917,7 +3914,7 @@ static bool constructAggregate(EvalInfo &Info, const FPOptions FPO,
while (!WorkList.empty() && ElI < Elements.size()) {
auto [Res, Type, BitWidth] = WorkList.pop_back_val();
- if (Type->isRealFloatingType() || Type->isBooleanType()) {
+ if (Type->isRealFloatingType()) {
if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
*Res))
return false;
@@ -3978,10 +3975,10 @@ static bool constructAggregate(EvalInfo &Info, const FPOptions FPO,
// we need to traverse backwards
// Visit the base classes.
if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
- // todo assert there is only 1 base at most
- for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
- const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
- ReverseList.emplace_back(&Res->getStructBase(I), BS.getType(), 0u);
+ if (CXXRD->getNumBases() > 0) {
+ assert(CXXRD->getNumBases() == 1);
+ const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
+ ReverseList.emplace_back(&Res->getStructBase(0), BS.getType(), 0u);
}
}
@@ -4057,13 +4054,12 @@ static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) {
}
if (Type->isRecordType()) {
const RecordDecl *RD = Type->getAsRecordDecl();
- // const ASTRecordLayout &Layout = Info.Ctx.getASTRecordLayout(RD);
// Visit the base classes.
if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
- // todo assert there is only 1 base at most
- for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
- const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+ if (CXXRD->getNumBases() > 0) {
+ assert(CXXRD->getNumBases() == 1);
+ const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
WorkList.push_back(BS.getType());
}
}
@@ -4126,7 +4122,6 @@ static bool flattenAPValue(const ASTContext &Ctx, APValue Value,
for (FieldDecl *FD : RD->fields()) {
if (FD->isUnnamedBitField())
continue;
- // if (FD->isBitField()) {
ReverseList.emplace_back(Work.getStructField(FD->getFieldIndex()),
FD->getType());
}
diff --git a/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
index 85b7271fd2c8c..d5b59851c8da6 100644
--- a/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
@@ -6,27 +6,8 @@
// good way to constexpr `any` for bool vector conditions, and the condition for
// _Static_assert must be an integral constant.
-struct S {
- int3 A;
- float B;
-};
-
export void fn() {
- _Static_assert(((float4)(int[6]){1,2,3,4,5,6}).x == 1.0, "Woo!");
-
- // This compiling successfully verifies that the array constant expression
- // gets truncated to a float at compile time for instantiation via the
- // flat cast
- _Static_assert(((int)(int[2]){1,2}) == 1, "Woo!");
-
- // This compiling successfully verifies that the struct constant expression
- // gets truncated to an integer at compile time for instatiation via the
- // flat cast
- _Static_assert(((int)(S){{1,2,3},1.0}) == 1, "Woo!");
-
- _Static_assert(((float)(float[2]){1.0,2.0}) == 1.0, "Woo!");
-
// This compiling successfully verifies that the vector constant expression
// gets truncated to an integer at compile time for instantiation.
_Static_assert(((int)1.xxxx) + 0 == 1, "Woo!");
diff --git a/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
index 6f697e8097a21..01af59d7771bb 100644
--- a/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
@@ -27,6 +27,13 @@ struct B2 : B1 {
};
export void fn() {
+/*
+ _Static_assert(((float4)(int[6]){1,2,3,4,5,6}).x == 1.0, "Woo!");
+
+ // This compiling successfully verifies that the array constant expression
+ // gets truncated to a float at compile time for instantiation via the
+ // flat cast
+ _Static_assert(((int)(int[2]){1,2}) == 1, "Woo!");
// truncation tests
// result type int
@@ -73,4 +80,13 @@ export void fn() {
_Static_assert(SB.U[1] == 1000, "Woo!");
_Static_assert(SB.I == 8, "Woo!");
_Static_assert(SB.I2 == 0, "Woo!");
+*/
+ // Make sure we read bitfields correctly
+ constexpr Base BB = {222.22, {100, 200}, -2, 7};
+ constexpr int Arr[5] = (int[5])BB;
+ _Static_assert(Arr[0] == 222, "Woo!");
+ _Static_assert(Arr[1] == 100, "Woo!");
+ _Static_assert(Arr[2] == 200, "Woo!");
+ _Static_assert(Arr[3] == -2, "Woo!");
+ _Static_assert(Arr[4] == 7, "Woo!");
}
>From 8a03b0df76f6ee73bdf7e04d815a960ccc05079d Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Thu, 23 Oct 2025 10:01:09 -0700
Subject: [PATCH 3/6] Fix issues introduced during self-review
---
clang/lib/AST/ExprConstant.cpp | 8 +++++---
.../BuiltinVector/TruncationConstantExpr.hlsl | 2 --
.../Types/ElementwiseCastConstantExpr.hlsl | 15 +++++++--------
3 files changed, 12 insertions(+), 13 deletions(-)
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index b679f3455e160..98a3719ea7cfc 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -3843,13 +3843,15 @@ static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
if (!HandleConversionToBool(Original, BoolResult))
return false;
uint64_t IntResult = BoolResult;
- Result = APValue(Info.Ctx.MakeIntValue(
- IntResult, Info.Ctx.getIntTypeForBitwidth(64, true)));
+ QualType IntType = DestTy->isIntegerType()
+ ? DestTy
+ : Info.Ctx.getIntTypeForBitwidth(64, false);
+ Result = APValue(Info.Ctx.MakeIntValue(IntResult, IntType));
}
if (DestTy->isFloatingType()) {
APValue Result2 = APValue(APFloat(0.0));
if (!HandleIntToFloatCast(Info, E, FPO,
- Info.Ctx.getIntTypeForBitwidth(64, true),
+ Info.Ctx.getIntTypeForBitwidth(64, false),
Result.getInt(), DestTy, Result2.getFloat()))
return false;
Result = Result2;
diff --git a/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
index d5b59851c8da6..918daa03d8032 100644
--- a/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
@@ -5,9 +5,7 @@
// Note: these tests are a bit awkward because at time of writing we don't have a
// good way to constexpr `any` for bool vector conditions, and the condition for
// _Static_assert must be an integral constant.
-
export void fn() {
-
// This compiling successfully verifies that the vector constant expression
// gets truncated to an integer at compile time for instantiation.
_Static_assert(((int)1.xxxx) + 0 == 1, "Woo!");
diff --git a/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
index 01af59d7771bb..1689fb091b624 100644
--- a/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
@@ -27,7 +27,6 @@ struct B2 : B1 {
};
export void fn() {
-/*
_Static_assert(((float4)(int[6]){1,2,3,4,5,6}).x == 1.0, "Woo!");
// This compiling successfully verifies that the array constant expression
@@ -80,13 +79,13 @@ export void fn() {
_Static_assert(SB.U[1] == 1000, "Woo!");
_Static_assert(SB.I == 8, "Woo!");
_Static_assert(SB.I2 == 0, "Woo!");
-*/
+
// Make sure we read bitfields correctly
constexpr Base BB = {222.22, {100, 200}, -2, 7};
- constexpr int Arr[5] = (int[5])BB;
- _Static_assert(Arr[0] == 222, "Woo!");
- _Static_assert(Arr[1] == 100, "Woo!");
- _Static_assert(Arr[2] == 200, "Woo!");
- _Static_assert(Arr[3] == -2, "Woo!");
- _Static_assert(Arr[4] == 7, "Woo!");
+ constexpr int Arr3[5] = (int[5])BB;
+ _Static_assert(Arr3[0] == 222, "Woo!");
+ _Static_assert(Arr3[1] == 100, "Woo!");
+ _Static_assert(Arr3[2] == 200, "Woo!");
+ _Static_assert(Arr3[3] == -2, "Woo!");
+ _Static_assert(Arr3[4] == 7, "Woo!");
}
>From ee833d12bdad0a70613b8af13cec3f3786a88a6c Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Thu, 30 Oct 2025 14:07:00 -0700
Subject: [PATCH 4/6] Updates based on PR comments
---
clang/lib/AST/ExprConstant.cpp | 12 +++++++-----
.../SemaHLSL/Types/AggregateSplatConstantExpr.hlsl | 2 +-
.../SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl | 1 -
3 files changed, 8 insertions(+), 7 deletions(-)
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index a452210dea3ab..40c777aff35c7 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -3849,7 +3849,7 @@ static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
: Info.Ctx.getIntTypeForBitwidth(64, false);
Result = APValue(Info.Ctx.MakeIntValue(IntResult, IntType));
}
- if (DestTy->isFloatingType()) {
+ if (DestTy->isRealFloatingType()) {
APValue Result2 = APValue(APFloat(0.0));
if (!HandleIntToFloatCast(Info, E, FPO,
Info.Ctx.getIntTypeForBitwidth(64, false),
@@ -4048,9 +4048,9 @@ static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) {
if (Type->isConstantArrayType()) {
QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
->getElementType();
- uint64_t Size =
+ uint64_t ArrSize =
cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
- for (uint64_t I = 0; I < Size; ++I) {
+ for (uint64_t I = 0; I < ArrSize; ++I) {
WorkList.push_back(ElTy);
}
continue;
@@ -11239,11 +11239,12 @@ bool RecordExprEvaluator::VisitCastExpr(const CastExpr *E) {
return false;
}
+ unsigned NEls = elementwiseSize(Info, E->getType());
// flatten the source
SmallVector<APValue> SrcEls;
SmallVector<QualType> SrcTypes;
if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
- UINT_MAX))
+ NEls))
return Error(E);
// cast the elements and construct our struct result
@@ -13617,11 +13618,12 @@ bool ArrayExprEvaluator::VisitCastExpr(const CastExpr *E) {
return false;
}
+ unsigned NEls = elementwiseSize(Info, E->getType());
// flatten the source
SmallVector<APValue> SrcEls;
SmallVector<QualType> SrcTypes;
if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
- UINT_MAX))
+ NEls))
return Error(E);
// cast the elements
diff --git a/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl b/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
index 7df41f24ee0d9..82fcf30b03709 100644
--- a/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
@@ -32,7 +32,6 @@ export void fn() {
// splat from a vector of size 1
constexpr float1 Y = {1.0};
- constexpr int1 A1 = {1};
constexpr float4 F4 = (float4)Y;
_Static_assert(F4[0] == 1.0, "Woo!");
_Static_assert(F4[1] == 1.0, "Woo!");
@@ -51,6 +50,7 @@ export void fn() {
_Static_assert(B6[5] == 3, "Woo!");
// splat from a vector of size 1
+ constexpr int1 A1 = {1};
constexpr uint64_t2 A7[2] = (uint64_t2[2])A1;
_Static_assert(A7[0][0] == 1, "Woo!");
_Static_assert(A7[0][1] == 1, "Woo!");
diff --git a/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
index 1689fb091b624..756e74b399ebf 100644
--- a/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
@@ -38,7 +38,6 @@ export void fn() {
// result type int
// truncate from struct
constexpr B1 SB1 = {1.0, 3.0};
- constexpr float Blah = SB1.A;
constexpr int X = (int)SB1;
_Static_assert(X == 1, "Woo!");
>From 43956578143e66113b2b3a01fcb64480f5cbd4cb Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Thu, 30 Oct 2025 15:11:57 -0700
Subject: [PATCH 5/6] Apply clang-format
---
clang/lib/AST/ExprConstant.cpp | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 40c777aff35c7..6a097431514e0 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -11243,8 +11243,7 @@ bool RecordExprEvaluator::VisitCastExpr(const CastExpr *E) {
// flatten the source
SmallVector<APValue> SrcEls;
SmallVector<QualType> SrcTypes;
- if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
- NEls))
+ if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
return Error(E);
// cast the elements and construct our struct result
@@ -13622,8 +13621,7 @@ bool ArrayExprEvaluator::VisitCastExpr(const CastExpr *E) {
// flatten the source
SmallVector<APValue> SrcEls;
SmallVector<QualType> SrcTypes;
- if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
- NEls))
+ if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
return Error(E);
// cast the elements
>From 2b1e68ef8142f849f1ae8fba88e2fc65bc8b67b4 Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Mon, 3 Nov 2025 11:25:54 -0800
Subject: [PATCH 6/6] Apply suggested review changes and related follow-up
 changes inspired by them
---
clang/lib/AST/ExprConstant.cpp | 254 +++++++-----------
.../Types/AggregateSplatConstantExpr.hlsl | 4 +-
2 files changed, 105 insertions(+), 153 deletions(-)
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 6a097431514e0..f91faf6050bf8 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -3899,6 +3899,7 @@ static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
}
}
+ Info.FFDiag(E, diag::note_invalid_subexpr_in_const_expr);
return false;
}
@@ -3959,9 +3960,8 @@ static bool constructAggregate(EvalInfo &Info, const FPOptions FPO,
uint64_t Size =
cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
*Res = APValue(APValue::UninitArray(), Size, Size);
- for (int64_t I = Size - 1; I > -1; --I) {
+ for (int64_t I = Size - 1; I > -1; --I)
WorkList.emplace_back(&Res->getArrayInitializedElt(I), ElTy, 0u);
- }
continue;
}
if (Type->isRecordType()) {
@@ -4002,6 +4002,7 @@ static bool constructAggregate(EvalInfo &Info, const FPOptions FPO,
llvm::append_range(WorkList, ReverseList);
continue;
}
+ Info.FFDiag(E, diag::note_invalid_subexpr_in_const_expr);
return false;
}
return true;
@@ -4079,7 +4080,26 @@ static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) {
return Size;
}
-static bool flattenAPValue(const ASTContext &Ctx, APValue Value,
+static bool hlslAggSplatHelper(EvalInfo &Info, const Expr *E, APValue &SrcVal,
+ QualType &SrcTy) {
+ SrcTy = E->getType();
+
+ if (!Evaluate(SrcVal, Info, E))
+ return false;
+
+ assert(SrcVal.isFloat() || SrcVal.isInt() ||
+ (SrcVal.isVector() && SrcVal.getVectorLength() == 1) &&
+ "Not a valid HLSLAggregateSplatCast.");
+
+ if (SrcVal.isVector()) {
+ assert(SrcTy->isVectorType() && "Type mismatch.");
+ SrcTy = SrcTy->castAs<VectorType>()->getElementType();
+ SrcVal = SrcVal.getVectorElt(0);
+ }
+ return true;
+}
+
+static bool flattenAPValue(EvalInfo &Info, const Expr *E, APValue Value,
QualType BaseTy, SmallVectorImpl<APValue> &Elements,
SmallVectorImpl<QualType> &Types, unsigned Size) {
@@ -4088,7 +4108,7 @@ static bool flattenAPValue(const ASTContext &Ctx, APValue Value,
while (!WorkList.empty() && Populated < Size) {
auto [Work, Type] = WorkList.pop_back_val();
- if (Work.isFloat() || Work.isInt()) { // todo what does this do with bool
+ if (Work.isFloat() || Work.isInt()) {
Elements.push_back(Work);
Types.push_back(Type);
Populated++;
@@ -4107,8 +4127,8 @@ static bool flattenAPValue(const ASTContext &Ctx, APValue Value,
}
if (Work.isArray()) {
assert(Type->isConstantArrayType() && "Type mismatch.");
- QualType ElTy =
- cast<ConstantArrayType>(Ctx.getAsArrayType(Type))->getElementType();
+ QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+ ->getElementType();
for (int64_t I = Work.getArraySize() - 1; I > -1; --I) {
WorkList.emplace_back(Work.getArrayInitializedElt(I), ElTy);
}
@@ -4148,6 +4168,7 @@ static bool flattenAPValue(const ASTContext &Ctx, APValue Value,
}
continue;
}
+ Info.FFDiag(E, diag::note_invalid_subexpr_in_const_expr);
return false;
}
return true;
@@ -4963,6 +4984,30 @@ handleLValueToRValueConversion(EvalInfo &Info, const Expr *Conv, QualType Type,
return Obj && extractSubobject(Info, Conv, Obj, LVal.Designator, RVal, AK);
}
+static bool hlslElementwiseCastHelper(EvalInfo &Info, const Expr *E,
+ QualType DestTy,
+ SmallVectorImpl<APValue> &SrcVals,
+ SmallVectorImpl<QualType> &SrcTypes) {
+ APValue Val;
+ if (!Evaluate(Val, Info, E))
+ return false;
+
+ // must be dealing with a record
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ if (!handleLValueToRValueConversion(Info, E, E->getType(), LVal, Val))
+ return false;
+ }
+
+ unsigned NEls = elementwiseSize(Info, DestTy);
+ // flatten the source
+ if (!flattenAPValue(Info, E, Val, E->getType(), SrcVals, SrcTypes, NEls))
+ return false;
+
+ return true;
+}
+
/// Perform an assignment of Val to LVal. Takes ownership of Val.
static bool handleAssignment(EvalInfo &Info, const Expr *E, const LValue &LVal,
QualType LValType, APValue &Val) {
@@ -11196,62 +11241,37 @@ bool RecordExprEvaluator::VisitCastExpr(const CastExpr *E) {
}
case CK_HLSLAggregateSplatCast: {
APValue Val;
- const Expr *SE = E->getSubExpr();
+ QualType ValTy;
- if (!Evaluate(Val, Info, SE))
- return Error(E);
+ if (!hlslAggSplatHelper(Info, E->getSubExpr(), Val, ValTy))
+ return false;
unsigned NEls = elementwiseSize(Info, E->getType());
- // flatten the source
- SmallVector<APValue, 1> SrcEls;
- SmallVector<QualType, 1> SrcTypes;
- if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
- return Error(E);
-
- // check there is only one and splat it
- assert(SrcEls.size() == 1);
- SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
- SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
-
- APValue Tmp;
- handleDefaultInitValue(E->getType(), Tmp);
+ // splat our Val
+ SmallVector<APValue> SplatEls(NEls, Val);
+ SmallVector<QualType> SplatType(NEls, ValTy);
// cast the elements and construct our struct result
const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
SplatType))
- return Error(E);
+ return false;
return true;
}
case CK_HLSLElementwiseCast: {
- APValue Val;
- const Expr *SE = E->getSubExpr();
-
- if (!Evaluate(Val, Info, SE))
- return Error(E);
-
- // must be dealing with a record;
- if (Val.isLValue()) {
- LValue LVal;
- LVal.setFrom(Info.Ctx, Val);
- if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
- return false;
- }
-
- unsigned NEls = elementwiseSize(Info, E->getType());
- // flatten the source
SmallVector<APValue> SrcEls;
SmallVector<QualType> SrcTypes;
- if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
- return Error(E);
+
+ if (!hlslElementwiseCastHelper(Info, E->getSubExpr(), E->getType(), SrcEls,
+ SrcTypes))
+ return false;
// cast the elements and construct our struct result
const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
-
if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
SrcTypes))
- return Error(E);
+ return false;
return true;
}
@@ -11752,54 +11772,34 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) {
}
case CK_HLSLAggregateSplatCast: {
APValue Val;
+ QualType ValTy;
- if (!Evaluate(Val, Info, SE))
- return Error(E);
-
- // this cast doesn't handle splatting from scalars when result is a vector
- SmallVector<APValue, 1> Elements;
- SmallVector<QualType, 1> DestTypes = {VTy->getElementType()};
- SmallVector<QualType, 1> SrcTypes;
- if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
- return Error(E);
+ if (!hlslAggSplatHelper(Info, SE, Val, ValTy))
+ return false;
- // check there is only one element and cast and splat it
- assert(Elements.size() == 1 &&
- "HLSLAggregateSplatCast RHS must contain one element");
+ // cast our Val once.
+ APValue Result;
const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
- SmallVector<APValue, 1> ResultEls(1);
- if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
- ResultEls))
- return Error(E);
+ if (!handleScalarCast(Info, FPO, E, ValTy, VTy->getElementType(), Val,
+ Result))
+ return false;
- SmallVector<APValue, 4> SplatEls(NElts, ResultEls[0]);
+ SmallVector<APValue, 4> SplatEls(NElts, Result);
return Success(SplatEls, E);
}
case CK_HLSLElementwiseCast: {
- APValue Val;
-
- if (!Evaluate(Val, Info, SE))
- return Error(E);
+ SmallVector<APValue> SrcVals;
+ SmallVector<QualType> SrcTypes;
- // must be dealing with a record;
- if (Val.isLValue()) {
- LValue LVal;
- LVal.setFrom(Info.Ctx, Val);
- if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
- return false;
- }
+ if (!hlslElementwiseCastHelper(Info, SE, E->getType(), SrcVals, SrcTypes))
+ return false;
- SmallVector<APValue, 4> Elements;
- SmallVector<QualType, 4> DestTypes(NElts, VTy->getElementType());
- SmallVector<QualType, 4> SrcTypes;
- if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
- return Error(E);
- // cast elements
const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ SmallVector<QualType, 4> DestTypes(NElts, VTy->getElementType());
SmallVector<APValue, 4> ResultEls(NElts);
- if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+ if (!handleElementwiseCast(Info, E, FPO, SrcVals, SrcTypes, DestTypes,
ResultEls))
- return Error(E);
+ return false;
return Success(ResultEls, E);
}
default:
@@ -13579,57 +13579,36 @@ bool ArrayExprEvaluator::VisitCastExpr(const CastExpr *E) {
return ExprEvaluatorBaseTy::VisitCastExpr(E);
case CK_HLSLAggregateSplatCast: {
APValue Val;
+ QualType ValTy;
- if (!Evaluate(Val, Info, SE))
- return Error(E);
+ if (!hlslAggSplatHelper(Info, SE, Val, ValTy))
+ return false;
unsigned NEls = elementwiseSize(Info, E->getType());
- // flatten the source
- SmallVector<APValue, 1> SrcEls;
- SmallVector<QualType, 1> SrcTypes;
- if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
- return Error(E);
- // check there is only one and splat it
- assert(SrcEls.size() == 1);
- SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
- SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+ SmallVector<APValue> SplatEls(NEls, Val);
+ SmallVector<QualType> SplatType(NEls, ValTy);
// cast the elements
const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
SplatType))
- return Error(E);
+ return false;
return true;
}
case CK_HLSLElementwiseCast: {
- APValue Val;
-
- if (!Evaluate(Val, Info, SE))
- return Error(E);
-
- // must be dealing with a record;
- if (Val.isLValue()) {
- LValue LVal;
- LVal.setFrom(Info.Ctx, Val);
- if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
- return false;
- }
-
- unsigned NEls = elementwiseSize(Info, E->getType());
- // flatten the source
SmallVector<APValue> SrcEls;
SmallVector<QualType> SrcTypes;
- if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
- return Error(E);
+
+ if (!hlslElementwiseCastHelper(Info, SE, E->getType(), SrcEls, SrcTypes))
+ return false;
// cast the elements
const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
SrcTypes))
- return Error(E);
-
+ return false;
return true;
}
}
@@ -17505,32 +17484,18 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
return Success(Val.getVectorElt(0), E);
}
case CK_HLSLElementwiseCast: {
- APValue Val;
-
- if (!Evaluate(Val, Info, SubExpr))
- return Error(E);
-
- // must be dealing with a record;
- if (Val.isLValue()) {
- LValue LVal;
- LVal.setFrom(Info.Ctx, Val);
- if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(),
- LVal, Val))
- return false;
- }
+ SmallVector<APValue> SrcVals;
+ SmallVector<QualType> SrcTypes;
- SmallVector<APValue, 1> Elements;
- SmallVector<QualType, 1> DestTypes(1, DestType);
- SmallVector<QualType, 1> SrcTypes;
- if (!flattenAPValue(Info.Ctx, Val, SrcType, Elements, SrcTypes, 1))
- return Error(E);
+ if (!hlslElementwiseCastHelper(Info, SubExpr, DestType, SrcVals, SrcTypes))
+ return false;
// cast our single element
const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
APValue ResultVal;
- if (!handleScalarCast(Info, FPO, E, SrcTypes[0], DestTypes[0], Elements[0],
+ if (!handleScalarCast(Info, FPO, E, SrcTypes[0], DestType, SrcVals[0],
ResultVal))
- return Error(E);
+ return false;
return Success(ResultVal, E);
}
}
@@ -18112,33 +18077,20 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr *E) {
return Success(Val.getVectorElt(0), E);
}
case CK_HLSLElementwiseCast: {
- APValue Val;
-
- if (!Evaluate(Val, Info, SubExpr))
- return Error(E);
-
- // must be dealing with a record;
- if (Val.isLValue()) {
- LValue LVal;
- LVal.setFrom(Info.Ctx, Val);
- if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(),
- LVal, Val))
- return false;
- }
+ SmallVector<APValue> SrcVals;
+ SmallVector<QualType> SrcTypes;
- SmallVector<APValue, 1> Elements;
- SmallVector<QualType, 1> DestTypes(1, E->getType());
- SmallVector<QualType, 1> SrcTypes;
- if (!flattenAPValue(Info.Ctx, Val, SubExpr->getType(), Elements, SrcTypes,
- 1))
- return Error(E);
+ if (!hlslElementwiseCastHelper(Info, SubExpr, E->getType(), SrcVals,
+ SrcTypes))
+ return false;
+ APValue Val;
// cast our single element
const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
APValue ResultVal;
- if (!handleScalarCast(Info, FPO, E, SrcTypes[0], DestTypes[0], Elements[0],
+ if (!handleScalarCast(Info, FPO, E, SrcTypes[0], E->getType(), SrcVals[0],
ResultVal))
- return Error(E);
+ return false;
return Success(ResultVal, E);
}
}
diff --git a/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl b/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
index 82fcf30b03709..a298d6024bd42 100644
--- a/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
@@ -59,8 +59,8 @@ export void fn() {
// result type struct
// splat from a scalar
- constexpr double D = 100.6789;
- constexpr R SR = (R)D;
+ constexpr double D = 97.6789;
+ constexpr R SR = (R)(D + 3.0);
_Static_assert(SR.D == 100.6789, "Woo!");
_Static_assert(SR.U[0] == 100, "Woo!");
_Static_assert(SR.U[1] == 100, "Woo!");
More information about the cfe-commits
mailing list