[clang] [HLSL] add support for HLSLAggregateSplatCast and HLSLElementwiseCast to constant expression evaluator (PR #164700)
Sarah Spall via cfe-commits
cfe-commits at lists.llvm.org
Wed Oct 22 13:05:20 PDT 2025
https://github.com/spall created https://github.com/llvm/llvm-project/pull/164700
Add support to handle these casts in the constant expression evaluator.
- HLSLAggregateSplatCast
- HLSLElementwiseCast
- HLSLArrayRValue
Add tests
>From 276fca41ad8e81ce4189266c20d260646d6d5f4c Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Fri, 31 Jan 2025 16:57:24 -0800
Subject: [PATCH] add support for HLSLAggregateSplatCast and
HLSLElementwiseCast to Constant expression evaluator. Add tests. Fix/Add
support for other minor necessary things.
---
clang/lib/AST/ExprConstant.cpp | 587 +++++++++++++++++-
.../Types/AggregateSplatConstantExpr.hlsl | 89 +++
.../BuiltinVector/TruncationConstantExpr.hlsl | 21 +
.../Types/ElementwiseCastConstantExpr.hlsl | 76 +++
4 files changed, 772 insertions(+), 1 deletion(-)
create mode 100644 clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
create mode 100644 clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 00aaaab957591..5dfb2b3e3491f 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -3828,6 +3828,333 @@ static bool CheckArraySize(EvalInfo &Info, const ConstantArrayType *CAT,
/*Diag=*/true);
}
+/// Converts the scalar constant \p Original from \p SourceTy to \p DestTy and
+/// stores the converted value in \p Result.
+///
+/// Supported source and destination types are bool, integer and real
+/// floating-point types, in any combination. Returns true on success and
+/// false if a helper conversion fails or the type pair is not supported.
+static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
+                             QualType SourceTy, QualType DestTy,
+                             APValue const &Original, APValue &Result) {
+  // Boolean must be checked before integer since isIntegerType() is also true
+  // for bool.
+  if (SourceTy->isBooleanType()) {
+    if (DestTy->isBooleanType()) {
+      Result = Original;
+      return true;
+    }
+    // Reject unsupported destinations instead of reporting success with
+    // Result left unset.
+    if (!DestTy->isIntegerType() && !DestTy->isRealFloatingType())
+      return false;
+
+    bool BoolResult;
+    if (!HandleConversionToBool(Original, BoolResult))
+      return false;
+    uint64_t IntResult = BoolResult;
+
+    if (DestTy->isIntegerType()) {
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+
+    // bool -> float: DestTy is not an integer type, so materialize 0/1 in a
+    // real 64-bit integer type first and convert that to the destination.
+    QualType IntTy = Info.Ctx.getIntTypeForBitwidth(64, true);
+    APSInt IntVal = Info.Ctx.MakeIntValue(IntResult, IntTy);
+    Result = APValue(APFloat(0.0));
+    return HandleIntToFloatCast(Info, E, FPO, IntTy, IntVal, DestTy,
+                                Result.getFloat());
+  }
+  if (SourceTy->isIntegerType()) {
+    if (DestTy->isRealFloatingType()) {
+      Result = APValue(APFloat(0.0));
+      return HandleIntToFloatCast(Info, E, FPO, SourceTy, Original.getInt(),
+                                  DestTy, Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(
+          HandleIntToIntCast(Info, E, DestTy, SourceTy, Original.getInt()));
+      return true;
+    }
+  } else if (SourceTy->isRealFloatingType()) {
+    if (DestTy->isRealFloatingType()) {
+      // Start from a copy of the source and convert it in place.
+      Result = Original;
+      return HandleFloatToFloatCast(Info, E, SourceTy, DestTy,
+                                    Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(APSInt());
+      return HandleFloatToIntCast(Info, E, SourceTy, Original.getFloat(),
+                                  DestTy, Result.getInt());
+    }
+  }
+
+  // TODO: emit a dedicated diagnostic for unsupported scalar conversions.
+  return false;
+}
+
+/// Builds a constant of type \p ResultType in \p Result from the flat list
+/// \p Elements, whose types are given by the parallel list \p ElTypes.
+///
+/// The destination type is walked depth-first in declaration order (base
+/// classes before fields). Each scalar slot consumes the next element, which
+/// is converted with handleScalarCast; integer bit-field members are
+/// additionally truncated to their declared width. Returns false if a
+/// conversion fails, an unsupported destination type is encountered, or the
+/// elements run out before every scalar slot has been filled.
+static bool constructAggregate(EvalInfo &Info, const FPOptions FPO,
+                               const Expr *E, APValue &Result,
+                               QualType ResultType,
+                               SmallVectorImpl<APValue> &Elements,
+                               SmallVectorImpl<QualType> &ElTypes) {
+  assert(Elements.size() == ElTypes.size() &&
+         "every element needs a matching source type");
+
+  // Each entry is (destination slot, type of the slot, bit-field width or 0).
+  SmallVector<std::tuple<APValue *, QualType, unsigned>> WorkList = {
+      {&Result, ResultType, 0}};
+
+  unsigned ElI = 0;
+  // Drain the whole work list so that aggregates without scalar slots are
+  // still initialized after the last element has been consumed.
+  while (!WorkList.empty()) {
+    auto [Res, Type, BitWidth] = WorkList.pop_back_val();
+
+    if (Type->isRealFloatingType() || Type->isBooleanType()) {
+      if (ElI >= Elements.size())
+        return false;
+      if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+                            *Res))
+        return false;
+      ++ElI;
+      continue;
+    }
+    if (Type->isIntegerType()) {
+      if (ElI >= Elements.size())
+        return false;
+      if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+                            *Res))
+        return false;
+      if (BitWidth > 0) {
+        if (!Res->isInt())
+          return false;
+        // Keep only the bits that fit in the bit-field, then widen back to
+        // the storage width of the field's type.
+        APSInt &Int = Res->getInt();
+        unsigned OldBitWidth = Int.getBitWidth();
+        if (BitWidth < OldBitWidth)
+          Int = Int.trunc(BitWidth).extend(OldBitWidth);
+      }
+      ++ElI;
+      continue;
+    }
+    if (Type->isVectorType()) {
+      QualType ElTy = Type->castAs<VectorType>()->getElementType();
+      unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+      SmallVector<APValue> Vals(NumEl);
+      for (unsigned I = 0; I < NumEl; ++I) {
+        // Never read past the end of the source element list.
+        if (ElI >= Elements.size())
+          return false;
+        if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], ElTy, Elements[ElI],
+                              Vals[I]))
+          return false;
+        ++ElI;
+      }
+      *Res = APValue(Vals.data(), NumEl);
+      continue;
+    }
+    if (Type->isConstantArrayType()) {
+      const auto *CAT = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type));
+      QualType ElTy = CAT->getElementType();
+      uint64_t Size = CAT->getZExtSize();
+      *Res = APValue(APValue::UninitArray(), Size, Size);
+      // Push back to front so that element 0 is popped first.
+      for (uint64_t I = Size; I-- > 0;)
+        WorkList.emplace_back(&Res->getArrayInitializedElt(I), ElTy, 0u);
+      continue;
+    }
+    if (Type->isRecordType()) {
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD);
+      unsigned NumBases = CXXRD ? CXXRD->getNumBases() : 0;
+
+      *Res = APValue(APValue::UninitStruct(), NumBases,
+                     std::distance(RD->field_begin(), RD->field_end()));
+
+      // Collect the subobjects in declaration order, then push them reversed
+      // so that the stack pops them in declaration order.
+      SmallVector<std::tuple<APValue *, QualType, unsigned>> ReverseList;
+
+      // Visit the base classes.
+      // TODO: assert there is at most one base.
+      for (unsigned I = 0; I != NumBases; ++I) {
+        const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+        ReverseList.emplace_back(&Res->getStructBase(I), BS.getType(), 0u);
+      }
+
+      // Visit the fields; unnamed bit-fields never receive a value.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        unsigned FDBW = FD->isBitField() ? FD->getBitWidthValue() : 0;
+        ReverseList.emplace_back(&Res->getStructField(FD->getFieldIndex()),
+                                 FD->getType(), FDBW);
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+      continue;
+    }
+    // Unsupported destination type.
+    return false;
+  }
+  return true;
+}
+
+/// Converts each entry of \p Elements from its type in \p SrcTypes to the
+/// corresponding type in \p DestTypes, writing the converted value into
+/// \p Results at the same index.
+///
+/// \p Elements, \p SrcTypes and \p DestTypes must have the same length and
+/// \p Results must already hold at least that many entries. Returns false as
+/// soon as one conversion fails.
+static bool handleElementwiseCast(EvalInfo &Info, const Expr *E,
+                                  const FPOptions FPO,
+                                  SmallVectorImpl<APValue> &Elements,
+                                  SmallVectorImpl<QualType> &SrcTypes,
+                                  SmallVectorImpl<QualType> &DestTypes,
+                                  SmallVectorImpl<APValue> &Results) {
+
+  assert((Elements.size() == SrcTypes.size()) &&
+         (Elements.size() == DestTypes.size()));
+  assert(Results.size() >= Elements.size() &&
+         "Results must be pre-sized by the caller");
+
+  for (unsigned I = 0, ESz = Elements.size(); I < ESz; ++I) {
+    // handleScalarCast only reads the source value, so pass the element
+    // directly instead of copying it first.
+    if (!handleScalarCast(Info, FPO, E, SrcTypes[I], DestTypes[I], Elements[I],
+                          Results[I]))
+      return false;
+  }
+  return true;
+}
+
+/// Returns the number of scalar elements obtained by flattening \p BaseTy.
+///
+/// Floating-point, integer and boolean types count as one element, vectors
+/// count one per lane, and constant arrays and records are expanded
+/// recursively (base classes plus all fields except unnamed bit-fields).
+/// Any other type contributes nothing to the count.
+static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) {
+
+  SmallVector<QualType> WorkList = {BaseTy};
+
+  unsigned Size = 0;
+  while (!WorkList.empty()) {
+    QualType Type = WorkList.pop_back_val();
+    if (Type->isRealFloatingType() || Type->isIntegerType() ||
+        Type->isBooleanType()) {
+      ++Size;
+      continue;
+    }
+    if (Type->isVectorType()) {
+      unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+      Size += NumEl;
+      continue;
+    }
+    if (Type->isConstantArrayType()) {
+      const auto *CAT = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type));
+      QualType ElTy = CAT->getElementType();
+      // Named distinctly from the running total so the two cannot be
+      // confused.
+      uint64_t ArrSize = CAT->getZExtSize();
+      for (uint64_t I = 0; I < ArrSize; ++I) {
+        WorkList.push_back(ElTy);
+      }
+      continue;
+    }
+    if (Type->isRecordType()) {
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        // TODO: assert there is at most one base.
+        for (size_t I = 0, NumBases = CXXRD->getNumBases(); I != NumBases;
+             ++I) {
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+          WorkList.push_back(BS.getType());
+        }
+      }
+
+      // Visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        WorkList.push_back(FD->getType());
+      }
+      continue;
+    }
+  }
+  return Size;
+}
+
+/// Flattens \p Value, of type \p BaseTy, into its scalar leaves.
+///
+/// Leaves are appended to \p Elements and their types to \p Types in
+/// declaration order (a base class before the fields of the derived record).
+/// At most \p Size leaves are produced. Returns false if a value of an
+/// unsupported kind is encountered or a base-class subobject is not a struct
+/// value; returns true otherwise, even when no leaf was produced.
+static bool flattenAPValue(const ASTContext &Ctx, APValue Value,
+                           QualType BaseTy, SmallVectorImpl<APValue> &Elements,
+                           SmallVectorImpl<QualType> &Types, unsigned Size) {
+
+  SmallVector<std::pair<APValue, QualType>> WorkList = {{Value, BaseTy}};
+  unsigned Populated = 0;
+  while (!WorkList.empty() && Populated < Size) {
+    auto [Work, Type] = WorkList.pop_back_val();
+
+    // Scalar leaf; this branch is taken for any value stored as Int or Float.
+    if (Work.isFloat() || Work.isInt()) {
+      Elements.push_back(Work);
+      Types.push_back(Type);
+      Populated++;
+      continue;
+    }
+    if (Work.isVector()) {
+      assert(Type->isVectorType() && "Type mismatch.");
+      QualType ElTy = Type->castAs<VectorType>()->getElementType();
+      for (unsigned I = 0; I < Work.getVectorLength() && Populated < Size;
+           I++) {
+        Elements.push_back(Work.getVectorElt(I));
+        Types.push_back(ElTy);
+        Populated++;
+      }
+      continue;
+    }
+    if (Work.isArray()) {
+      assert(Type->isConstantArrayType() && "Type mismatch.");
+      QualType ElTy =
+          cast<ConstantArrayType>(Ctx.getAsArrayType(Type))->getElementType();
+      // Elements past the initialized prefix are represented only by the
+      // array filler, so read those from the filler instead of indexing past
+      // the initialized storage.
+      const unsigned NumInit = Work.getArrayInitializedElts();
+      // Count down without subtracting from an unsigned size, so an empty
+      // array pushes nothing. Pushing back to front pops element 0 first.
+      for (unsigned I = Work.getArraySize(); I-- > 0;) {
+        WorkList.emplace_back(I < NumInit ? Work.getArrayInitializedElt(I)
+                                          : Work.getArrayFiller(),
+                              ElTy);
+      }
+      continue;
+    }
+
+    if (Work.isStruct()) {
+      assert(Type->isRecordType() && "Type mismatch.");
+
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      SmallVector<std::pair<APValue, QualType>> ReverseList;
+      // Visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        ReverseList.emplace_back(Work.getStructField(FD->getFieldIndex()),
+                                 FD->getType());
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+
+      // Visit the base class. It is pushed last so that it is popped, and
+      // therefore flattened, before the fields.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        if (CXXRD->getNumBases() > 0) {
+          assert(CXXRD->getNumBases() == 1);
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
+          const APValue &Base = Work.getStructBase(0);
+
+          // Can happen in error cases.
+          if (!Base.isStruct())
+            return false;
+
+          WorkList.emplace_back(Base, BS.getType());
+        }
+      }
+      continue;
+    }
+    // Unsupported value kind.
+    return false;
+  }
+  return true;
+}
+
namespace {
/// A handle to a complete object (an object that is not a subobject of
/// another object).
@@ -8666,6 +8993,25 @@ class ExprEvaluatorBase
case CK_UserDefinedConversion:
return StmtVisitorTy::Visit(E->getSubExpr());
+ case CK_HLSLArrayRValue: {
+ // Produce the value of the array operand. A non-glvalue operand is
+ // evaluated directly; a glvalue operand is evaluated as an lvalue and then
+ // loaded.
+ const Expr *SubExpr = E->getSubExpr();
+ if (!SubExpr->isGLValue()) {
+ APValue Val;
+ if (!Evaluate(Val, Info, SubExpr))
+ return false;
+ return DerivedSuccess(Val, E);
+ }
+
+ LValue LVal;
+ if (!EvaluateLValue(SubExpr, LVal, Info))
+ return false;
+ APValue RVal;
+ // Note, we use the subexpression's type in order to retain cv-qualifiers.
+ if (!handleLValueToRValueConversion(Info, E, SubExpr->getType(), LVal,
+ RVal))
+ return false;
+ return DerivedSuccess(RVal, E);
+ }
case CK_LValueToRValue: {
LValue LVal;
if (!EvaluateLValue(E->getSubExpr(), LVal, Info))
@@ -10850,6 +11196,67 @@ bool RecordExprEvaluator::VisitCastExpr(const CastExpr *E) {
Result = *Value;
return true;
}
+  case CK_HLSLAggregateSplatCast: {
+    // Splat cast to a record: evaluate the single-element source, replicate
+    // it once per scalar slot of the destination type, then convert each
+    // copy while building the struct.
+    APValue Val;
+    const Expr *SE = E->getSubExpr();
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    unsigned NEls = elementwiseSize(Info, E->getType());
+    // Flatten the source.
+    SmallVector<APValue, 1> SrcEls;
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+      return Error(E);
+
+    // There must be exactly one source element to splat. Flattening can
+    // succeed without producing any element, so also guard builds in which
+    // the assertion is compiled out.
+    assert(SrcEls.size() == 1);
+    if (SrcEls.empty())
+      return Error(E);
+    SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+    SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+    // Cast the elements and construct our struct result.
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+                            SplatType))
+      return Error(E);
+
+    return true;
+  }
+ case CK_HLSLElementwiseCast: {
+ // Elementwise cast to a record: flatten the source into scalars, then
+ // convert them one by one while building the destination struct.
+ APValue Val;
+ const Expr *SE = E->getSubExpr();
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ // If evaluation produced an lvalue, load the value it designates before
+ // flattening.
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+ return false;
+ }
+
+ // Flatten the source; UINT_MAX places no practical cap on the number of
+ // elements collected.
+ SmallVector<APValue> SrcEls;
+ SmallVector<QualType> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+ UINT_MAX))
+ return Error(E);
+
+ // cast the elements and construct our struct result
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+
+ if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+ SrcTypes))
+ return Error(E);
+
+ return true;
+ }
}
}
@@ -11345,6 +11752,58 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) {
Elements.push_back(Val.getVectorElt(I));
return Success(Elements, E);
}
+ case CK_HLSLAggregateSplatCast: {
+ // Splat cast to a vector: convert the single source element to the vector
+ // element type once, then replicate the converted value into every lane.
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ // this cast doesn't handle splatting from scalars when result is a vector
+ SmallVector<APValue, 1> Elements;
+ SmallVector<QualType, 1> DestTypes = {VTy->getElementType()};
+ SmallVector<QualType, 1> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+ return Error(E);
+
+ // check there is only one element and cast and splat it
+ assert(Elements.size() == 1 &&
+ "HLSLAggregateSplatCast RHS must contain one element");
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ // handleElementwiseCast writes into existing slots, so reserve the single
+ // result slot up front.
+ SmallVector<APValue, 1> ResultEls(1);
+ if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+ ResultEls))
+ return Error(E);
+
+ SmallVector<APValue, 4> SplatEls(NElts, ResultEls[0]);
+ return Success(SplatEls, E);
+ }
+ case CK_HLSLElementwiseCast: {
+ // Elementwise cast to a vector: take at most NElts scalars from the
+ // flattened source and convert each to the vector element type.
+ APValue Val;
+
+ if (!Evaluate(Val, Info, SE))
+ return Error(E);
+
+ // If evaluation produced an lvalue, load the value it designates before
+ // flattening.
+ if (Val.isLValue()) {
+ LValue LVal;
+ LVal.setFrom(Info.Ctx, Val);
+ if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+ return false;
+ }
+
+ SmallVector<APValue, 4> Elements;
+ SmallVector<QualType, 4> DestTypes(NElts, VTy->getElementType());
+ SmallVector<QualType, 4> SrcTypes;
+ if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+ return Error(E);
+ // cast elements
+ const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+ // handleElementwiseCast writes into existing slots, so size the result
+ // list up front.
+ SmallVector<APValue, 4> ResultEls(NElts);
+ if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+ ResultEls))
+ return Error(E);
+ return Success(ResultEls, E);
+ }
default:
return ExprEvaluatorBaseTy::VisitCastExpr(E);
}
@@ -13029,6 +13488,7 @@ namespace {
bool VisitCallExpr(const CallExpr *E) {
return handleCallExpr(E, Result, &This);
}
+ bool VisitCastExpr(const CastExpr *E);
bool VisitInitListExpr(const InitListExpr *E,
QualType AllocType = QualType());
bool VisitArrayInitLoopExpr(const ArrayInitLoopExpr *E);
@@ -13099,6 +13559,70 @@ static bool MaybeElementDependentArrayFiller(const Expr *FillerExpr) {
return true;
}
+/// Evaluates casts that produce an array value. HLSL aggregate splat casts
+/// and HLSL elementwise casts are handled here; every other cast kind is
+/// forwarded to the base evaluator.
+bool ArrayExprEvaluator::VisitCastExpr(const CastExpr *E) {
+  const Expr *SE = E->getSubExpr();
+  const CastKind Kind = E->getCastKind();
+
+  if (Kind != CK_HLSLAggregateSplatCast && Kind != CK_HLSLElementwiseCast)
+    return ExprEvaluatorBaseTy::VisitCastExpr(E);
+
+  APValue Val;
+  if (!Evaluate(Val, Info, SE))
+    return Error(E);
+
+  // Flat list of source scalars (and their types) fed to constructAggregate.
+  SmallVector<APValue> FlatEls;
+  SmallVector<QualType> FlatTypes;
+
+  if (Kind == CK_HLSLAggregateSplatCast) {
+    const unsigned NumSlots = elementwiseSize(Info, E->getType());
+
+    // Flatten the source.
+    SmallVector<APValue, 1> Single;
+    SmallVector<QualType, 1> SingleType;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), Single, SingleType,
+                        NumSlots))
+      return Error(E);
+
+    // Check there is only one element and replicate it for every slot.
+    assert(Single.size() == 1);
+    FlatEls.assign(NumSlots, Single[0]);
+    FlatTypes.assign(NumSlots, SingleType[0]);
+  } else {
+    // If evaluation produced an lvalue, load the value it designates.
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    // Flatten the source.
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), FlatEls, FlatTypes,
+                        UINT_MAX))
+      return Error(E);
+  }
+
+  // Cast the elements while building the array result.
+  const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+  if (!constructAggregate(Info, FPO, E, Result, E->getType(), FlatEls,
+                          FlatTypes))
+    return Error(E);
+
+  return true;
+}
+
bool ArrayExprEvaluator::VisitInitListExpr(const InitListExpr *E,
QualType AllocType) {
const ConstantArrayType *CAT = Info.Ctx.getAsConstantArrayType(
@@ -16801,7 +17325,6 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_NoOp:
case CK_LValueToRValueBitCast:
case CK_HLSLArrayRValue:
- case CK_HLSLElementwiseCast:
return ExprEvaluatorBaseTy::VisitCastExpr(E);
case CK_MemberPointerToBoolean:
@@ -16948,6 +17471,35 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
return Error(E);
return Success(Val.getVectorElt(0), E);
}
+  case CK_HLSLElementwiseCast: {
+    // Elementwise cast to an integer: only the first scalar of the flattened
+    // source is converted.
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SubExpr))
+      return Error(E);
+
+    // If evaluation produced an lvalue, load the value it designates before
+    // flattening.
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(),
+                                          LVal, Val))
+        return false;
+    }
+
+    SmallVector<APValue, 1> Elements;
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SrcType, Elements, SrcTypes, 1))
+      return Error(E);
+    // flattenAPValue can succeed without producing any element (a source
+    // with no scalar leaves); never index into an empty list.
+    if (Elements.empty())
+      return Error(E);
+
+    // Cast our single element.
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    APValue ResultVal;
+    if (!handleScalarCast(Info, FPO, E, SrcTypes[0], DestType, Elements[0],
+                          ResultVal))
+      return Error(E);
+    return Success(ResultVal, E);
+  }
}
llvm_unreachable("unknown cast resulting in integral value");
@@ -17485,6 +18037,9 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr *E) {
default:
return ExprEvaluatorBaseTy::VisitCastExpr(E);
+ case CK_HLSLAggregateSplatCast:
+ llvm_unreachable("invalid cast kind for floating value");
+
case CK_IntegralToFloating: {
APSInt IntResult;
const FPOptions FPO = E->getFPFeaturesInEffect(
@@ -17523,6 +18078,36 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr *E) {
return Error(E);
return Success(Val.getVectorElt(0), E);
}
+  case CK_HLSLElementwiseCast: {
+    // Elementwise cast to a floating-point value: only the first scalar of
+    // the flattened source is converted.
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SubExpr))
+      return Error(E);
+
+    // If evaluation produced an lvalue, load the value it designates before
+    // flattening.
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(),
+                                          LVal, Val))
+        return false;
+    }
+
+    SmallVector<APValue, 1> Elements;
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SubExpr->getType(), Elements, SrcTypes,
+                        1))
+      return Error(E);
+    // flattenAPValue can succeed without producing any element (a source
+    // with no scalar leaves); never index into an empty list.
+    if (Elements.empty())
+      return Error(E);
+
+    // Cast our single element.
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    APValue ResultVal;
+    if (!handleScalarCast(Info, FPO, E, SrcTypes[0], E->getType(), Elements[0],
+                          ResultVal))
+      return Error(E);
+    return Success(ResultVal, E);
+  }
}
}
diff --git a/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl b/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
new file mode 100644
index 0000000000000..7df41f24ee0d9
--- /dev/null
+++ b/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
@@ -0,0 +1,89 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -fnative-half-type -std=hlsl202x -verify %s
+
+// expected-no-diagnostics
+
+// Base mixes a double, a two-element vector and two 5-bit bit-fields.
+struct Base {
+ double D;
+ uint64_t2 U;
+ int16_t I : 5;
+ uint16_t I2: 5;
+};
+
+// R adds a 10-bit bit-field, an unnamed bit-field and a float on top of Base.
+struct R : Base {
+ int G : 10;
+ int : 30;
+ float F;
+};
+
+struct B1 {
+ float A;
+ float B;
+};
+
+// B2 extends B1 with two ints and a bool.
+struct B2 : B1 {
+ int C;
+ int D;
+ bool BB;
+};
+
+// tests for HLSLAggregateSplatCast
+export void fn() {
+ // result type vector
+ // splat from a vector of size 1
+
+ constexpr float1 Y = {1.0};
+ constexpr int1 A1 = {1};
+ constexpr float4 F4 = (float4)Y;
+ _Static_assert(F4[0] == 1.0, "Woo!");
+ _Static_assert(F4[1] == 1.0, "Woo!");
+ _Static_assert(F4[2] == 1.0, "Woo!");
+ _Static_assert(F4[3] == 1.0, "Woo!");
+
+ // result type array
+ // splat from a scalar
+ // every int element is expected to hold 3, i.e. 3.33 without its fraction
+ constexpr float F = 3.33;
+ constexpr int B6[6] = (int[6])F;
+ _Static_assert(B6[0] == 3, "Woo!");
+ _Static_assert(B6[1] == 3, "Woo!");
+ _Static_assert(B6[2] == 3, "Woo!");
+ _Static_assert(B6[3] == 3, "Woo!");
+ _Static_assert(B6[4] == 3, "Woo!");
+ _Static_assert(B6[5] == 3, "Woo!");
+
+ // splat from a vector of size 1
+ constexpr uint64_t2 A7[2] = (uint64_t2[2])A1;
+ _Static_assert(A7[0][0] == 1, "Woo!");
+ _Static_assert(A7[0][1] == 1, "Woo!");
+ _Static_assert(A7[1][0] == 1, "Woo!");
+ _Static_assert(A7[1][1] == 1, "Woo!");
+
+ // result type struct
+ // splat from a scalar
+ // The 5-bit fields I and I2 are expected to hold 4, which is 100 reduced
+ // to its low five bits (100 % 32); the 10-bit field G holds 100 unchanged.
+ constexpr double D = 100.6789;
+ constexpr R SR = (R)D;
+ _Static_assert(SR.D == 100.6789, "Woo!");
+ _Static_assert(SR.U[0] == 100, "Woo!");
+ _Static_assert(SR.U[1] == 100, "Woo!");
+ _Static_assert(SR.I == 4, "Woo!");
+ _Static_assert(SR.I2 == 4, "Woo!");
+ _Static_assert(SR.G == 100, "Woo!");
+ _Static_assert(SR.F == 100.6789, "Woo!");
+
+ // splat from a vector of size 1
+ constexpr float1 A100 = {1000.1111};
+ constexpr B2 SB2 = (B2)A100;
+ _Static_assert(SB2.A == 1000.1111, "Woo!");
+ _Static_assert(SB2.B == 1000.1111, "Woo!");
+ _Static_assert(SB2.C == 1000, "Woo!");
+ _Static_assert(SB2.D == 1000, "Woo!");
+ _Static_assert(SB2.BB == true, "Woo!");
+
+ // splat from a bool to an int and float etc
+ constexpr bool B = true;
+ constexpr B2 SB3 = (B2)B;
+ _Static_assert(SB3.A == 1.0, "Woo!");
+ _Static_assert(SB3.B == 1.0, "Woo!");
+ _Static_assert(SB3.C == 1, "Woo!");
+ _Static_assert(SB3.D == 1, "Woo!");
+ _Static_assert(SB3.BB == true, "Woo!");
+}
diff --git a/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
index 918daa03d8032..85b7271fd2c8c 100644
--- a/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
@@ -5,7 +5,28 @@
// Note: these tests are a bit awkward because at time of writing we don't have a
// good way to constexpr `any` for bool vector conditions, and the condition for
// _Static_assert must be an integral constant.
+
+struct S {
+ int3 A;
+ float B;
+};
+
export void fn() {
+
+ _Static_assert(((float4)(int[6]){1,2,3,4,5,6}).x == 1.0, "Woo!");
+
+ // This compiling successfully verifies that the array constant expression
+ // gets truncated to an integer at compile time for instantiation via the
+ // flat cast
+ _Static_assert(((int)(int[2]){1,2}) == 1, "Woo!");
+
+ // This compiling successfully verifies that the struct constant expression
+ // gets truncated to an integer at compile time for instantiation via the
+ // flat cast
+ _Static_assert(((int)(S){{1,2,3},1.0}) == 1, "Woo!");
+
+ _Static_assert(((float)(float[2]){1.0,2.0}) == 1.0, "Woo!");
+
// This compiling successfully verifies that the vector constant expression
// gets truncated to an integer at compile time for instantiation.
_Static_assert(((int)1.xxxx) + 0 == 1, "Woo!");
diff --git a/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
new file mode 100644
index 0000000000000..6f697e8097a21
--- /dev/null
+++ b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
@@ -0,0 +1,76 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -fnative-half-type -std=hlsl202x -verify %s
+
+// expected-no-diagnostics
+
+// Base mixes a double, a two-element vector and two 5-bit bit-fields.
+struct Base {
+ double D;
+ uint64_t2 U;
+ int16_t I : 5;
+ uint16_t I2: 5;
+};
+
+// R adds a 10-bit bit-field, an unnamed bit-field and a float on top of Base.
+struct R : Base {
+ int G : 10;
+ int : 30;
+ float F;
+};
+
+struct B1 {
+ float A;
+ float B;
+};
+
+// B2 extends B1 with two ints and a bool.
+struct B2 : B1 {
+ int C;
+ int D;
+ bool BB;
+};
+
+// tests for HLSLElementwiseCast
+export void fn() {
+
+ // truncation tests
+ // result type int
+ // truncate from struct
+ // only the first member (A == 1.0) is expected to contribute to the result
+ constexpr B1 SB1 = {1.0, 3.0};
+ constexpr float Blah = SB1.A;
+ constexpr int X = (int)SB1;
+ _Static_assert(X == 1, "Woo!");
+
+ // result type float
+ // truncate from array
+ constexpr B1 Arr[2] = {4.0, 3.0, 2.0, 1.0};
+ constexpr float F = (float)Arr;
+ _Static_assert(F == 4.0, "Woo!");
+
+ // result type vector
+ // truncate from array of vector
+ constexpr int2 Arr2[2] = {5,6,7,8};
+ constexpr int2 I2 = (int2)Arr2;
+ _Static_assert(I2[0] == 5, "Woo!");
+ _Static_assert(I2[1] == 6, "Woo!");
+
+ // lhs and rhs are same "size" tests
+
+ // result type vector from array
+ constexpr int4 I4 = (int4)Arr;
+ _Static_assert(I4[0] == 4, "Woo!");
+ _Static_assert(I4[1] == 3, "Woo!");
+ _Static_assert(I4[2] == 2, "Woo!");
+ _Static_assert(I4[3] == 1, "Woo!");
+
+ // result type array from vector
+ constexpr double3 D3 = {100.11, 200.11, 300.11};
+ constexpr float FArr[3] = (float[3])D3;
+ _Static_assert(FArr[0] == 100.11, "Woo!");
+ _Static_assert(FArr[1] == 200.11, "Woo!");
+ _Static_assert(FArr[2] == 300.11, "Woo!");
+
+ // result type struct from struct
+ // The source values 5.5, 6.5, 1000, 5000, false are assigned in order to
+ // D, U[0], U[1], I and I2. The 5-bit field I is expected to hold 8, which
+ // is 5000 reduced to its low five bits (5000 % 32), and I2 takes false (0).
+ constexpr B2 SB2 = {5.5, 6.5, 1000, 5000, false};
+ constexpr Base SB = (Base)SB2;
+ _Static_assert(SB.D == 5.5, "Woo!");
+ _Static_assert(SB.U[0] == 6, "Woo!");
+ _Static_assert(SB.U[1] == 1000, "Woo!");
+ _Static_assert(SB.I == 8, "Woo!");
+ _Static_assert(SB.I2 == 0, "Woo!");
+}
More information about the cfe-commits
mailing list