[clang] [HLSL] add support for HLSLAggregateSplatCast and HLSLElementwiseCast to constant expression evaluator (PR #164700)

Sarah Spall via cfe-commits cfe-commits at lists.llvm.org
Wed Oct 22 13:05:20 PDT 2025


https://github.com/spall created https://github.com/llvm/llvm-project/pull/164700

Add support to handle these casts in the constant expression evaluator. 

- HLSLAggregateSplatCast
- HLSLElementwiseCast
- HLSLArrayRValue

Add tests 

>From 276fca41ad8e81ce4189266c20d260646d6d5f4c Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Fri, 31 Jan 2025 16:57:24 -0800
Subject: [PATCH] add support for HLSLAggregateSplatCast and
 HLSLElementwiseCast to Constant expression evaluator. Add tests. Fix/Add
 support for other minor necessary things.

---
 clang/lib/AST/ExprConstant.cpp                | 587 +++++++++++++++++-
 .../Types/AggregateSplatConstantExpr.hlsl     |  89 +++
 .../BuiltinVector/TruncationConstantExpr.hlsl |  21 +
 .../Types/ElementwiseCastConstantExpr.hlsl    |  76 +++
 4 files changed, 772 insertions(+), 1 deletion(-)
 create mode 100644 clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
 create mode 100644 clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl

diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 00aaaab957591..5dfb2b3e3491f 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -3828,6 +3828,333 @@ static bool CheckArraySize(EvalInfo &Info, const ConstantArrayType *CAT,
       /*Diag=*/true);
 }
 
+static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
+                             QualType SourceTy, QualType DestTy,
+                             APValue const &Original, APValue &Result) {
+  // Boolean must be checked before integer,
+  // since isIntegerType() is also true for bool.
+  if (SourceTy->isBooleanType()) {
+    if (DestTy->isBooleanType()) {
+      Result = Original;
+      return true;
+    }
+    if (DestTy->isIntegerType() || DestTy->isRealFloatingType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      // TODO: DestTy is wrong here if DestTy is a floating-point type;
+      // can we use SourceTy here instead?
+    }
+    if (DestTy->isFloatingType()) {
+      APValue Result2 = APValue(APFloat(0.0));
+      if (!HandleIntToFloatCast(Info, E, FPO,
+                                Info.Ctx.getIntTypeForBitwidth(64, true),
+                                Result.getInt(), DestTy, Result2.getFloat()))
+        return false;
+      Result = Result2;
+    }
+    return true;
+  }
+  if (SourceTy->isIntegerType()) {
+    if (DestTy->isRealFloatingType()) {
+      Result = APValue(APFloat(0.0));
+      return HandleIntToFloatCast(Info, E, FPO, SourceTy, Original.getInt(),
+                                  DestTy, Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(
+          HandleIntToIntCast(Info, E, DestTy, SourceTy, Original.getInt()));
+      return true;
+    }
+  } else if (SourceTy->isRealFloatingType()) {
+    if (DestTy->isRealFloatingType()) {
+      Result = Original;
+      return HandleFloatToFloatCast(Info, E, SourceTy, DestTy,
+                                    Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(APSInt());
+      return HandleFloatToIntCast(Info, E, SourceTy, Original.getFloat(),
+                                  DestTy, Result.getInt());
+    }
+  }
+
+  // Info.FFDiag(E, diag::err_convertvector_constexpr_unsupported_vector_cast)
+  //   << SourceTy << DestTy;
+  return false;
+}
+
+// Do the heavy lifting for casting to aggregate types,
+// because we have to deal with bit-fields specially.
+static bool constructAggregate(EvalInfo &Info, const FPOptions FPO,
+                               const Expr *E, APValue &Result,
+                               QualType ResultType,
+                               SmallVectorImpl<APValue> &Elements,
+                               SmallVectorImpl<QualType> &ElTypes) {
+
+  SmallVector<std::tuple<APValue *, QualType, unsigned>> WorkList = {
+      {&Result, ResultType, 0}};
+
+  unsigned ElI = 0;
+  while (!WorkList.empty() && ElI < Elements.size()) {
+    auto [Res, Type, BitWidth] = WorkList.pop_back_val();
+
+    if (Type->isRealFloatingType() || Type->isBooleanType()) {
+      if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+                            *Res))
+        return false;
+      ElI++;
+      continue;
+    }
+    if (Type->isIntegerType()) {
+      if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+                            *Res))
+        return false;
+      if (BitWidth > 0) {
+        if (!Res->isInt())
+          return false;
+        APSInt &Int = Res->getInt();
+        unsigned OldBitWidth = Int.getBitWidth();
+        unsigned NewBitWidth = BitWidth;
+        if (NewBitWidth < OldBitWidth)
+          Int = Int.trunc(NewBitWidth).extend(OldBitWidth);
+      }
+      ElI++;
+      continue;
+    }
+    if (Type->isVectorType()) {
+      QualType ElTy = Type->castAs<VectorType>()->getElementType();
+      unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+      SmallVector<APValue> Vals(NumEl);
+      for (unsigned I = 0; I < NumEl; ++I) {
+        if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], ElTy, Elements[ElI],
+                              Vals[I]))
+          return false;
+        ElI++;
+      }
+      *Res = APValue(Vals.data(), NumEl);
+      continue;
+    }
+    if (Type->isConstantArrayType()) {
+      QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+                          ->getElementType();
+      uint64_t Size =
+          cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+      *Res = APValue(APValue::UninitArray(), Size, Size);
+      for (int64_t I = Size - 1; I > -1; --I) {
+        WorkList.emplace_back(&Res->getArrayInitializedElt(I), ElTy, 0u);
+      }
+      continue;
+    }
+    if (Type->isRecordType()) {
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      unsigned NumBases = 0;
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD))
+        NumBases = CXXRD->getNumBases();
+
+      *Res = APValue(APValue::UninitStruct(), NumBases,
+                     std::distance(RD->field_begin(), RD->field_end()));
+
+      SmallVector<std::tuple<APValue *, QualType, unsigned>> ReverseList;
+      // we need to traverse backwards
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        // todo assert there is only 1 base at most
+        for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+          ReverseList.emplace_back(&Res->getStructBase(I), BS.getType(), 0u);
+        }
+      }
+
+      // Visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        unsigned FDBW = 0;
+        if (FD->isUnnamedBitField())
+          continue;
+        if (FD->isBitField()) {
+          FDBW = FD->getBitWidthValue();
+        }
+
+        ReverseList.emplace_back(&Res->getStructField(FD->getFieldIndex()),
+                                 FD->getType(), FDBW);
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
+static bool handleElementwiseCast(EvalInfo &Info, const Expr *E,
+                                  const FPOptions FPO,
+                                  SmallVectorImpl<APValue> &Elements,
+                                  SmallVectorImpl<QualType> &SrcTypes,
+                                  SmallVectorImpl<QualType> &DestTypes,
+                                  SmallVectorImpl<APValue> &Results) {
+
+  assert((Elements.size() == SrcTypes.size()) &&
+         (Elements.size() == DestTypes.size()));
+
+  for (unsigned I = 0, ESz = Elements.size(); I < ESz; ++I) {
+    APValue Original = Elements[I];
+    QualType SourceTy = SrcTypes[I];
+    QualType DestTy = DestTypes[I];
+
+    if (!handleScalarCast(Info, FPO, E, SourceTy, DestTy, Original, Results[I]))
+      return false;
+  }
+  return true;
+}
+
+static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) {
+
+  SmallVector<QualType> WorkList = {BaseTy};
+
+  unsigned Size = 0;
+  while (!WorkList.empty()) {
+    QualType Type = WorkList.pop_back_val();
+    if (Type->isRealFloatingType() || Type->isIntegerType() ||
+        Type->isBooleanType()) {
+      ++Size;
+      continue;
+    }
+    if (Type->isVectorType()) {
+      unsigned NumEl = Type->castAs<VectorType>()->getNumElements();
+      Size += NumEl;
+      continue;
+    }
+    if (Type->isConstantArrayType()) {
+      QualType ElTy = cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))
+                          ->getElementType();
+      uint64_t Size =
+          cast<ConstantArrayType>(Info.Ctx.getAsArrayType(Type))->getZExtSize();
+      for (uint64_t I = 0; I < Size; ++I) {
+        WorkList.push_back(ElTy);
+      }
+      continue;
+    }
+    if (Type->isRecordType()) {
+      const RecordDecl *RD = Type->getAsRecordDecl();
+      // const ASTRecordLayout &Layout = Info.Ctx.getASTRecordLayout(RD);
+
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        // todo assert there is only 1 base at most
+        for (size_t I = 0, E = CXXRD->getNumBases(); I != E; ++I) {
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[I];
+          WorkList.push_back(BS.getType());
+        }
+      }
+
+      // visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        WorkList.push_back(FD->getType());
+      }
+      continue;
+    }
+  }
+  return Size;
+}
+
+static bool flattenAPValue(const ASTContext &Ctx, APValue Value,
+                           QualType BaseTy, SmallVectorImpl<APValue> &Elements,
+                           SmallVectorImpl<QualType> &Types, unsigned Size) {
+
+  SmallVector<std::pair<APValue, QualType>> WorkList = {{Value, BaseTy}};
+  unsigned Populated = 0;
+  while (!WorkList.empty() && Populated < Size) {
+    auto [Work, Type] = WorkList.pop_back_val();
+
+    if (Work.isFloat() || Work.isInt()) { // todo what does this do with bool
+      Elements.push_back(Work);
+      Types.push_back(Type);
+      Populated++;
+      continue;
+    }
+    if (Work.isVector()) {
+      assert(Type->isVectorType() && "Type mismatch.");
+      QualType ElTy = Type->castAs<VectorType>()->getElementType();
+      for (unsigned I = 0; I < Work.getVectorLength() && Populated < Size;
+           I++) {
+        Elements.push_back(Work.getVectorElt(I));
+        Types.push_back(ElTy);
+        Populated++;
+      }
+      continue;
+    }
+    if (Work.isArray()) {
+      assert(Type->isConstantArrayType() && "Type mismatch.");
+      QualType ElTy =
+          cast<ConstantArrayType>(Ctx.getAsArrayType(Type))->getElementType();
+      for (int64_t I = Work.getArraySize() - 1; I > -1; --I) {
+        WorkList.emplace_back(Work.getArrayInitializedElt(I), ElTy);
+      }
+      continue;
+    }
+
+    if (Work.isStruct()) {
+      assert(Type->isRecordType() && "Type mismatch.");
+
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      SmallVector<std::pair<APValue, QualType>> ReverseList;
+      // Visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        // if (FD->isBitField()) {
+        ReverseList.emplace_back(Work.getStructField(FD->getFieldIndex()),
+                                 FD->getType());
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        if (CXXRD->getNumBases() > 0) {
+          assert(CXXRD->getNumBases() == 1);
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
+          const APValue &Base = Work.getStructBase(0);
+
+          // Can happen in error cases.
+          if (!Base.isStruct())
+            return false;
+
+          WorkList.emplace_back(Base, BS.getType());
+        }
+      }
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
 namespace {
 /// A handle to a complete object (an object that is not a subobject of
 /// another object).
@@ -8666,6 +8993,25 @@ class ExprEvaluatorBase
     case CK_UserDefinedConversion:
       return StmtVisitorTy::Visit(E->getSubExpr());
 
+    case CK_HLSLArrayRValue: {
+      const Expr *SubExpr = E->getSubExpr();
+      if (!SubExpr->isGLValue()) {
+        APValue Val;
+        if (!Evaluate(Val, Info, SubExpr))
+          return false;
+        return DerivedSuccess(Val, E);
+      }
+
+      LValue LVal;
+      if (!EvaluateLValue(SubExpr, LVal, Info))
+        return false;
+      APValue RVal;
+      // Note, we use the subexpression's type in order to retain cv-qualifiers.
+      if (!handleLValueToRValueConversion(Info, E, SubExpr->getType(), LVal,
+                                          RVal))
+        return false;
+      return DerivedSuccess(RVal, E);
+    }
     case CK_LValueToRValue: {
       LValue LVal;
       if (!EvaluateLValue(E->getSubExpr(), LVal, Info))
@@ -10850,6 +11196,67 @@ bool RecordExprEvaluator::VisitCastExpr(const CastExpr *E) {
     Result = *Value;
     return true;
   }
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+    const Expr *SE = E->getSubExpr();
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    unsigned NEls = elementwiseSize(Info, E->getType());
+    // flatten the source
+    SmallVector<APValue, 1> SrcEls;
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+      return Error(E);
+
+    // check there is only one and splat it
+    assert(SrcEls.size() == 1);
+    SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+    SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+    APValue Tmp;
+    handleDefaultInitValue(E->getType(), Tmp);
+
+    // cast the elements and construct our struct result
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+                            SplatType))
+      return Error(E);
+
+    return true;
+  }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+    const Expr *SE = E->getSubExpr();
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    // flatten the source
+    SmallVector<APValue> SrcEls;
+    SmallVector<QualType> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+                        UINT_MAX))
+      return Error(E);
+
+    // cast the elements and construct our struct result
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+                            SrcTypes))
+      return Error(E);
+
+    return true;
+  }
   }
 }
 
@@ -11345,6 +11752,58 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) {
       Elements.push_back(Val.getVectorElt(I));
     return Success(Elements, E);
   }
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // this cast doesn't handle splatting from scalars when result is a vector
+    SmallVector<APValue, 1> Elements;
+    SmallVector<QualType, 1> DestTypes = {VTy->getElementType()};
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+      return Error(E);
+
+    // check there is only one element and cast and splat it
+    assert(Elements.size() == 1 &&
+           "HLSLAggregateSplatCast RHS must contain one element");
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    SmallVector<APValue, 1> ResultEls(1);
+    if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+                               ResultEls))
+      return Error(E);
+
+    SmallVector<APValue, 4> SplatEls(NElts, ResultEls[0]);
+    return Success(SplatEls, E);
+  }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    SmallVector<APValue, 4> Elements;
+    SmallVector<QualType, 4> DestTypes(NElts, VTy->getElementType());
+    SmallVector<QualType, 4> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SETy, Elements, SrcTypes, NElts))
+      return Error(E);
+    // cast elements
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    SmallVector<APValue, 4> ResultEls(NElts);
+    if (!handleElementwiseCast(Info, E, FPO, Elements, SrcTypes, DestTypes,
+                               ResultEls))
+      return Error(E);
+    return Success(ResultEls, E);
+  }
   default:
     return ExprEvaluatorBaseTy::VisitCastExpr(E);
   }
@@ -13029,6 +13488,7 @@ namespace {
     bool VisitCallExpr(const CallExpr *E) {
       return handleCallExpr(E, Result, &This);
     }
+    bool VisitCastExpr(const CastExpr *E);
     bool VisitInitListExpr(const InitListExpr *E,
                            QualType AllocType = QualType());
     bool VisitArrayInitLoopExpr(const ArrayInitLoopExpr *E);
@@ -13099,6 +13559,70 @@ static bool MaybeElementDependentArrayFiller(const Expr *FillerExpr) {
   return true;
 }
 
+bool ArrayExprEvaluator::VisitCastExpr(const CastExpr *E) {
+  const Expr *SE = E->getSubExpr();
+
+  switch (E->getCastKind()) {
+  default:
+    return ExprEvaluatorBaseTy::VisitCastExpr(E);
+  case CK_HLSLAggregateSplatCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    unsigned NEls = elementwiseSize(Info, E->getType());
+    // flatten the source
+    SmallVector<APValue, 1> SrcEls;
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes, NEls))
+      return Error(E);
+
+    // check there is only one and splat it
+    assert(SrcEls.size() == 1);
+    SmallVector<APValue> SplatEls(NEls, SrcEls[0]);
+    SmallVector<QualType> SplatType(NEls, SrcTypes[0]);
+
+    // cast the elements
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SplatEls,
+                            SplatType))
+      return Error(E);
+
+    return true;
+  }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SE))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SE, SE->getType(), LVal, Val))
+        return false;
+    }
+
+    // flatten the source
+    SmallVector<APValue> SrcEls;
+    SmallVector<QualType> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SE->getType(), SrcEls, SrcTypes,
+                        UINT_MAX))
+      return Error(E);
+
+    // cast the elements
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    if (!constructAggregate(Info, FPO, E, Result, E->getType(), SrcEls,
+                            SrcTypes))
+      return Error(E);
+
+    return true;
+  }
+  }
+}
+
 bool ArrayExprEvaluator::VisitInitListExpr(const InitListExpr *E,
                                            QualType AllocType) {
   const ConstantArrayType *CAT = Info.Ctx.getAsConstantArrayType(
@@ -16801,7 +17325,6 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_NoOp:
   case CK_LValueToRValueBitCast:
   case CK_HLSLArrayRValue:
-  case CK_HLSLElementwiseCast:
     return ExprEvaluatorBaseTy::VisitCastExpr(E);
 
   case CK_MemberPointerToBoolean:
@@ -16948,6 +17471,35 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
       return Error(E);
     return Success(Val.getVectorElt(0), E);
   }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SubExpr))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(),
+                                          LVal, Val))
+        return false;
+    }
+
+    SmallVector<APValue, 1> Elements;
+    SmallVector<QualType, 1> DestTypes(1, DestType);
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SrcType, Elements, SrcTypes, 1))
+      return Error(E);
+
+    // cast our single element
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    APValue ResultVal;
+    if (!handleScalarCast(Info, FPO, E, SrcTypes[0], DestTypes[0], Elements[0],
+                          ResultVal))
+      return Error(E);
+    return Success(ResultVal, E);
+  }
   }
 
   llvm_unreachable("unknown cast resulting in integral value");
@@ -17485,6 +18037,9 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr *E) {
   default:
     return ExprEvaluatorBaseTy::VisitCastExpr(E);
 
+  case CK_HLSLAggregateSplatCast:
+    llvm_unreachable("invalid cast kind for floating value");
+
   case CK_IntegralToFloating: {
     APSInt IntResult;
     const FPOptions FPO = E->getFPFeaturesInEffect(
@@ -17523,6 +18078,36 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr *E) {
       return Error(E);
     return Success(Val.getVectorElt(0), E);
   }
+  case CK_HLSLElementwiseCast: {
+    APValue Val;
+
+    if (!Evaluate(Val, Info, SubExpr))
+      return Error(E);
+
+    // must be dealing with a record;
+    if (Val.isLValue()) {
+      LValue LVal;
+      LVal.setFrom(Info.Ctx, Val);
+      if (!handleLValueToRValueConversion(Info, SubExpr, SubExpr->getType(),
+                                          LVal, Val))
+        return false;
+    }
+
+    SmallVector<APValue, 1> Elements;
+    SmallVector<QualType, 1> DestTypes(1, E->getType());
+    SmallVector<QualType, 1> SrcTypes;
+    if (!flattenAPValue(Info.Ctx, Val, SubExpr->getType(), Elements, SrcTypes,
+                        1))
+      return Error(E);
+
+    // cast our single element
+    const FPOptions FPO = E->getFPFeaturesInEffect(Info.Ctx.getLangOpts());
+    APValue ResultVal;
+    if (!handleScalarCast(Info, FPO, E, SrcTypes[0], DestTypes[0], Elements[0],
+                          ResultVal))
+      return Error(E);
+    return Success(ResultVal, E);
+  }
   }
 }
 
diff --git a/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl b/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
new file mode 100644
index 0000000000000..7df41f24ee0d9
--- /dev/null
+++ b/clang/test/SemaHLSL/Types/AggregateSplatConstantExpr.hlsl
@@ -0,0 +1,89 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -fnative-half-type -std=hlsl202x -verify %s
+
+// expected-no-diagnostics
+
+struct Base {
+  double D;
+  uint64_t2 U;
+  int16_t I : 5;
+  uint16_t I2: 5;
+};
+
+struct R : Base {
+  int G : 10;
+  int : 30;
+  float F;
+};
+
+struct B1 {
+  float A;
+  float B;
+};
+
+struct B2 : B1 {
+  int C;
+  int D;
+  bool BB;
+};
+
+// tests for HLSLAggregateSplatCast
+export void fn() {
+  // result type vector
+  // splat from a vector of size 1
+  
+  constexpr float1 Y = {1.0};
+  constexpr int1 A1 = {1};
+  constexpr float4 F4 = (float4)Y;
+  _Static_assert(F4[0] == 1.0, "Woo!");
+  _Static_assert(F4[1] == 1.0, "Woo!");
+  _Static_assert(F4[2] == 1.0, "Woo!");
+  _Static_assert(F4[3] == 1.0, "Woo!");
+
+  // result type array
+  // splat from a scalar
+  constexpr float F = 3.33;
+  constexpr int B6[6] = (int[6])F;
+  _Static_assert(B6[0] == 3, "Woo!");
+  _Static_assert(B6[1] == 3, "Woo!");
+  _Static_assert(B6[2] == 3, "Woo!");
+  _Static_assert(B6[3] == 3, "Woo!");
+  _Static_assert(B6[4] == 3, "Woo!");
+  _Static_assert(B6[5] == 3, "Woo!");
+
+  // splat from a vector of size 1
+  constexpr uint64_t2 A7[2] = (uint64_t2[2])A1;
+  _Static_assert(A7[0][0] == 1, "Woo!");
+  _Static_assert(A7[0][1] == 1, "Woo!");
+  _Static_assert(A7[1][0] == 1, "Woo!");
+  _Static_assert(A7[1][1] == 1, "Woo!");
+
+  // result type struct
+  // splat from a scalar
+  constexpr double D = 100.6789;
+  constexpr R SR = (R)D;
+  _Static_assert(SR.D == 100.6789, "Woo!");
+  _Static_assert(SR.U[0] == 100, "Woo!");
+  _Static_assert(SR.U[1] == 100, "Woo!");
+  _Static_assert(SR.I == 4, "Woo!");
+  _Static_assert(SR.I2 == 4, "Woo!");
+  _Static_assert(SR.G == 100, "Woo!");
+  _Static_assert(SR.F == 100.6789, "Woo!");
+
+  // splat from a vector of size 1
+  constexpr float1 A100 = {1000.1111};
+  constexpr B2 SB2 = (B2)A100;
+  _Static_assert(SB2.A == 1000.1111, "Woo!");
+  _Static_assert(SB2.B == 1000.1111, "Woo!");
+  _Static_assert(SB2.C == 1000, "Woo!");
+  _Static_assert(SB2.D == 1000, "Woo!");
+  _Static_assert(SB2.BB == true, "Woo!");
+
+  // splat from a bool to an int and float etc
+  constexpr bool B = true;
+  constexpr B2 SB3 = (B2)B;
+  _Static_assert(SB3.A == 1.0, "Woo!");
+  _Static_assert(SB3.B == 1.0, "Woo!");
+  _Static_assert(SB3.C == 1, "Woo!");
+  _Static_assert(SB3.D == 1, "Woo!");
+  _Static_assert(SB3.BB == true, "Woo!");
+}
diff --git a/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl b/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
index 918daa03d8032..85b7271fd2c8c 100644
--- a/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
+++ b/clang/test/SemaHLSL/Types/BuiltinVector/TruncationConstantExpr.hlsl
@@ -5,7 +5,28 @@
 // Note: these tests are a bit awkward because at time of writing we don't have a
 // good way to constexpr `any` for bool vector conditions, and the condition for
 // _Static_assert must be an integral constant.
+
+struct S {
+  int3 A;
+  float B;
+};
+
 export void fn() {
+
+  _Static_assert(((float4)(int[6]){1,2,3,4,5,6}).x == 1.0, "Woo!");
+
+  // This compiling successfully verifies that the array constant expression
+  // gets truncated to an integer at compile time for instantiation via the
+  // flat cast
+  _Static_assert(((int)(int[2]){1,2}) == 1, "Woo!");
+
+  // This compiling successfully verifies that the struct constant expression
+  // gets truncated to an integer at compile time for instantiation via the
+  // flat cast
+  _Static_assert(((int)(S){{1,2,3},1.0}) == 1, "Woo!");
+
+  _Static_assert(((float)(float[2]){1.0,2.0}) == 1.0, "Woo!");
+
   // This compiling successfully verifies that the vector constant expression
   // gets truncated to an integer at compile time for instantiation.
   _Static_assert(((int)1.xxxx) + 0 == 1, "Woo!");
diff --git a/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
new file mode 100644
index 0000000000000..6f697e8097a21
--- /dev/null
+++ b/clang/test/SemaHLSL/Types/ElementwiseCastConstantExpr.hlsl
@@ -0,0 +1,76 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -fnative-half-type -std=hlsl202x -verify %s
+
+// expected-no-diagnostics
+
+struct Base {
+  double D;
+  uint64_t2 U;
+  int16_t I : 5;
+  uint16_t I2: 5;
+};
+
+struct R : Base {
+  int G : 10;
+  int : 30;
+  float F;
+};
+
+struct B1 {
+  float A;
+  float B;
+};
+
+struct B2 : B1 {
+  int C;
+  int D;
+  bool BB;
+};
+
+export void fn() {
+
+  // truncation tests
+  // result type int
+  // truncate from struct
+  constexpr B1 SB1 = {1.0, 3.0};
+  constexpr float Blah = SB1.A;
+  constexpr int X = (int)SB1;
+  _Static_assert(X == 1, "Woo!");
+
+  // result type float
+  // truncate from array
+  constexpr B1 Arr[2] = {4.0, 3.0, 2.0, 1.0};
+  constexpr float F = (float)Arr;
+  _Static_assert(F == 4.0, "Woo!");
+
+  // result type vector
+  // truncate from array of vector
+  constexpr int2 Arr2[2] = {5,6,7,8};
+  constexpr int2 I2 = (int2)Arr2;
+  _Static_assert(I2[0] == 5, "Woo!");
+  _Static_assert(I2[1] == 6, "Woo!");
+
+  // lhs and rhs are same "size" tests
+  
+  // result type vector from array
+  constexpr int4 I4 = (int4)Arr;
+  _Static_assert(I4[0] == 4, "Woo!");
+  _Static_assert(I4[1] == 3, "Woo!");
+  _Static_assert(I4[2] == 2, "Woo!");
+  _Static_assert(I4[3] == 1, "Woo!");
+
+  // result type array from vector
+  constexpr double3 D3 = {100.11, 200.11, 300.11};
+  constexpr float FArr[3] = (float[3])D3;
+  _Static_assert(FArr[0] == 100.11, "Woo!");
+  _Static_assert(FArr[1] == 200.11, "Woo!");
+  _Static_assert(FArr[2] == 300.11, "Woo!");
+
+  // result type struct from struct
+  constexpr B2 SB2 = {5.5, 6.5, 1000, 5000, false};
+  constexpr Base SB = (Base)SB2;
+  _Static_assert(SB.D == 5.5, "Woo!");
+  _Static_assert(SB.U[0] == 6, "Woo!");
+  _Static_assert(SB.U[1] == 1000, "Woo!");
+  _Static_assert(SB.I == 8, "Woo!");
+  _Static_assert(SB.I2 == 0, "Woo!");
+}



More information about the cfe-commits mailing list