[clang] [HLSL] Implement HLSL splatting (PR #118992)

Sarah Spall via cfe-commits cfe-commits at lists.llvm.org
Fri Dec 6 09:41:29 PST 2024


https://github.com/spall updated https://github.com/llvm/llvm-project/pull/118992

>From 2e932a57ccb992b856b58bec4c30c6b64f24f711 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 28 Nov 2024 16:23:57 +0000
Subject: [PATCH 1/8] Flat casts WIP

---
 clang/include/clang/AST/OperationKinds.def    |   3 +
 clang/include/clang/Sema/SemaHLSL.h           |   2 +
 clang/lib/AST/Expr.cpp                        |   1 +
 clang/lib/AST/ExprConstant.cpp                |   1 +
 clang/lib/CodeGen/CGExpr.cpp                  |  84 ++++++++++
 clang/lib/CodeGen/CGExprAgg.cpp               |  83 +++++++++-
 clang/lib/CodeGen/CGExprComplex.cpp           |   1 +
 clang/lib/CodeGen/CGExprConstant.cpp          |   1 +
 clang/lib/CodeGen/CGExprScalar.cpp            |  39 +++++
 clang/lib/CodeGen/CodeGenFunction.h           |   7 +
 clang/lib/Edit/RewriteObjCFoundationAPI.cpp   |   1 +
 clang/lib/Sema/Sema.cpp                       |   1 +
 clang/lib/Sema/SemaCast.cpp                   |  20 ++-
 clang/lib/Sema/SemaHLSL.cpp                   | 143 ++++++++++++++++++
 clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp |   1 +
 15 files changed, 384 insertions(+), 4 deletions(-)

diff --git a/clang/include/clang/AST/OperationKinds.def b/clang/include/clang/AST/OperationKinds.def
index 8788b8ff0ef0a4..9323d4e861a734 100644
--- a/clang/include/clang/AST/OperationKinds.def
+++ b/clang/include/clang/AST/OperationKinds.def
@@ -367,6 +367,9 @@ CAST_OPERATION(HLSLVectorTruncation)
 // Non-decaying array RValue cast (HLSL only).
 CAST_OPERATION(HLSLArrayRValue)
 
+// Aggregate by Value cast (HLSL only).
+CAST_OPERATION(HLSLAggregateCast)
+
 //===- Binary Operations  -------------------------------------------------===//
 // Operators listed in order of precedence.
 // Note that additions to this should also update the StmtVisitor class,
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index ee685d95c96154..6bda1e8ce0ea5b 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -140,6 +140,8 @@ class SemaHLSL : public SemaBase {
   // Diagnose whether the input ID is uint/unit2/uint3 type.
   bool diagnoseInputIDType(QualType T, const ParsedAttr &AL);
 
+  bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
+  bool CanPerformAggregateCast(Expr *Src, QualType DestType);
   ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
 
   QualType getInoutParameterType(QualType Ty);
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index a4fb4d5a1f2ec4..4764bc84ce498a 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -1942,6 +1942,7 @@ bool CastExpr::CastConsistency() const {
   case CK_FixedPointToBoolean:
   case CK_HLSLArrayRValue:
   case CK_HLSLVectorTruncation:
+  case CK_HLSLAggregateCast:
   CheckNoBasePath:
     assert(path_empty() && "Cast kind should not have a base path!");
     break;
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 6b5b95aee35522..b548cef41b7525 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -15733,6 +15733,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_IntegralToFixedPoint:
   case CK_MatrixCast:
   case CK_HLSLVectorTruncation:
+  case CK_HLSLAggregateCast:
     llvm_unreachable("invalid cast kind for complex value");
 
   case CK_LValueToRValue:
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 5fccc9cbb37ec1..b7608b1226758d 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -5320,6 +5320,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
   case CK_MatrixCast:
   case CK_HLSLVectorTruncation:
   case CK_HLSLArrayRValue:
+  case CK_HLSLAggregateCast:
     return EmitUnsupportedLValue(E, "unexpected cast lvalue");
 
   case CK_Dependent:
@@ -6358,3 +6359,86 @@ RValue CodeGenFunction::EmitPseudoObjectRValue(const PseudoObjectExpr *E,
 LValue CodeGenFunction::EmitPseudoObjectLValue(const PseudoObjectExpr *E) {
   return emitPseudoObjectExpr(*this, E, true, AggValueSlot::ignored()).LV;
 }
+
+llvm::Value* CodeGenFunction::PerformLoad(std::pair<Address, llvm::Value *> &GEP) {
+  Address GEPAddress = GEP.first;
+  llvm::Value *Idx = GEP.second;
+  llvm::Value *V = Builder.CreateLoad(GEPAddress, "load");
+  if (Idx) { // loading from a vector so perform an extract as well
+    return Builder.CreateExtractElement(V, Idx, "vec.load");
+  }
+  return V;
+}
+
+llvm::Value* CodeGenFunction::PerformStore(std::pair<Address, llvm::Value *> &GEP,
+				           llvm::Value *Val) {
+  Address GEPAddress = GEP.first;
+  llvm::Value *Idx = GEP.second;
+  if (Idx) {
+    llvm::Value *V = Builder.CreateLoad(GEPAddress, "load.for.insert");
+    return Builder.CreateInsertElement(V, Val, Idx);
+  } else {
+    return Builder.CreateStore(Val, GEPAddress);
+  }
+}
+
+void CodeGenFunction::FlattenAccessAndType(Address Val, QualType SrcTy,
+			         SmallVector<llvm::Value *, 4> &IdxList,
+			         SmallVector<std::pair<Address, llvm::Value *>, 16> &GEPList,
+				 SmallVector<QualType> &FlatTypes) {
+  llvm::IntegerType *IdxTy = llvm::IntegerType::get(getLLVMContext(),32);
+  if (const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(SrcTy)) {
+    uint64_t Size = CAT->getZExtSize();
+    for(unsigned i = 0; i < Size; i ++) {
+      // flatten each member of the array
+      // add index of this element to index list
+      llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, i);
+      IdxList.push_back(Idx);
+      // recur on this object
+      FlattenAccessAndType(Val, CAT->getElementType(), IdxList, GEPList, FlatTypes);
+      // remove index of this element from index list
+      IdxList.pop_back();
+    }
+  } else if (const RecordType *RT = SrcTy->getAs<RecordType>()) {
+    RecordDecl *Record = RT->getDecl();
+    const CGRecordLayout &RL = getTypes().getCGRecordLayout(Record);
+    // do I need to check if its a cxx record decl?
+
+    for (auto fieldIter = Record->field_begin(), fieldEnd = Record->field_end();
+	 fieldIter != fieldEnd; ++fieldIter) {
+      // get the field number
+      unsigned FieldNum = RL.getLLVMFieldNo(*fieldIter);
+      // can we just do *fieldIter->getFieldIndex();
+      // add that index to the index list
+      llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, FieldNum);
+      IdxList.push_back(Idx);
+      // recur on the field
+      FlattenAccessAndType(Val, fieldIter->getType(), IdxList, GEPList,
+			   FlatTypes);
+      // remove index of this element from index list
+      IdxList.pop_back();
+    }
+  } else if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
+    llvm::Type *VTy = ConvertTypeForMem(SrcTy);
+    CharUnits Align = getContext().getTypeAlignInChars(SrcTy);
+    Address GEP = Builder.CreateInBoundsGEP(Val, IdxList,
+						 VTy, Align, "vector.gep");
+    for(unsigned i = 0; i < VT->getNumElements(); i ++) {
+      // add index to the list
+      llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, i);
+      // create gep. no need to recur since it's always a scalar
+      // gep on vector is not recommended so combine gep with extract/insert
+      GEPList.push_back({GEP, Idx});
+      FlatTypes.push_back(VT->getElementType());
+    }
+  } else { // should be a scalar; should we assert or check?
+    // create a gep
+    llvm::Type *Ty = ConvertTypeForMem(SrcTy);
+    CharUnits Align = getContext().getTypeAlignInChars(SrcTy);
+    Address GEP = Builder.CreateInBoundsGEP(Val, IdxList,
+						     Ty, Align,  "gep");
+    GEPList.push_back({GEP, NULL});
+    FlatTypes.push_back(SrcTy);
+  }
+  // target extension types?
+}
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 2ad6587089f101..bc8e1f0f9248ef 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -491,6 +491,70 @@ static bool isTrivialFiller(Expr *E) {
   return false;
 }
 
+
+
+// emit a flat cast where the RHS is a scalar, including vector
+static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
+			    QualType DestTy, llvm::Value *SrcVal,
+			    QualType SrcTy, SourceLocation Loc) {
+  // Flatten our destination
+  SmallVector<QualType> DestTypes; // Flattened type
+  SmallVector<llvm::Value *, 4> IdxList;
+  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
+  // ^^ Flattened accesses to DestVal we want to store into
+  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList,
+		       DestTypes);
+
+  if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
+    SrcTy = VT->getElementType();
+    assert(StoreGEPList.size() <= VT->getNumElements() &&
+	   "Cannot perform HLSL flat cast when vector source \
+           object has less elements than flattened destination \
+           object.");
+      for(unsigned i = 0; i < StoreGEPList.size(); i ++) {
+        llvm::Value *Load = CGF.Builder.CreateExtractElement(SrcVal, i,
+							     "vec.load");
+	llvm::Value *Cast = CGF.EmitScalarConversion(Load, SrcTy,
+						     DestTypes[i],
+						     Loc);
+	CGF.PerformStore(StoreGEPList[i], Cast);
+      }
+      return;
+  }
+  llvm_unreachable("HLSL Flat cast doesn't handle splatting.");
+}
+
+// emit a flat cast where the RHS is an aggregate
+static void EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address DestVal,
+			       QualType DestTy, Address SrcVal,
+			       QualType SrcTy, SourceLocation Loc) {
+  // Flatten our destination
+  SmallVector<QualType> DestTypes; // Flattened type
+  SmallVector<llvm::Value *, 4> IdxList;
+  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
+  // ^^ Flattened accesses to DestVal we want to store into
+  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList,
+		       DestTypes);
+  // Flatten our src
+  SmallVector<QualType> SrcTypes; // Flattened type
+  SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
+  // ^^ Flattened accesses to SrcVal we want to load from
+  IdxList.clear();
+  CGF.FlattenAccessAndType(SrcVal, SrcTy, IdxList, LoadGEPList, SrcTypes);
+
+  assert(StoreGEPList.size() <= LoadGEPList.size() &&
+	 "Cannot perform HLSL flat cast when flattened source object \
+          has less elements than flattened destination object.");
+  // apply casts to what we load from LoadGEPList
+  // and store result in Dest
+  for(unsigned i = 0; i < StoreGEPList.size(); i ++) {
+    llvm::Value *Load = CGF.PerformLoad(LoadGEPList[i]);
+    llvm::Value *Cast = CGF.EmitScalarConversion(Load, SrcTypes[i],
+						 DestTypes[i], Loc);
+    CGF.PerformStore(StoreGEPList[i], Cast);
+  }
+}
+
 /// Emit initialization of an array from an initializer list. ExprToVisit must
 /// be either an InitListEpxr a CXXParenInitListExpr.
 void AggExprEmitter::EmitArrayInit(Address DestPtr, llvm::ArrayType *AType,
@@ -890,7 +954,24 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
   case CK_HLSLArrayRValue:
     Visit(E->getSubExpr());
     break;
-
+  case CK_HLSLAggregateCast: {
+    Expr *Src = E->getSubExpr();
+    QualType SrcTy = Src->getType();
+    RValue RV = CGF.EmitAnyExpr(Src);
+    QualType DestTy = E->getType();
+    Address DestVal = Dest.getAddress();
+    SourceLocation Loc = E->getExprLoc();
+
+    if (RV.isScalar()) {
+      llvm::Value *SrcVal = RV.getScalarVal();
+      EmitHLSLScalarFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+    } else { // RHS is an aggregate
+      assert(RV.isAggregate() &&
+	     "Can't perform HLSL Aggregate cast on a complex type.");
+      Address SrcVal = RV.getAggregateAddress();
+      EmitHLSLAggregateFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+    }
+    break; }
   case CK_NoOp:
   case CK_UserDefinedConversion:
   case CK_ConstructorConversion:
diff --git a/clang/lib/CodeGen/CGExprComplex.cpp b/clang/lib/CodeGen/CGExprComplex.cpp
index ac31dff11b585e..05680d36aa2bd7 100644
--- a/clang/lib/CodeGen/CGExprComplex.cpp
+++ b/clang/lib/CodeGen/CGExprComplex.cpp
@@ -610,6 +610,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
   case CK_MatrixCast:
   case CK_HLSLVectorTruncation:
   case CK_HLSLArrayRValue:
+  case CK_HLSLAggregateCast:
     llvm_unreachable("invalid cast kind for complex value");
 
   case CK_FloatingRealToComplex:
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index 655fc3dc954c81..6d15bc9058e450 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -1335,6 +1335,7 @@ class ConstExprEmitter
     case CK_MatrixCast:
     case CK_HLSLVectorTruncation:
     case CK_HLSLArrayRValue:
+    case CK_HLSLAggregateCast:
       return nullptr;
     }
     llvm_unreachable("Invalid CastKind");
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 4ae8a2b22b1bba..d7bb702ec3ca20 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2262,6 +2262,35 @@ bool CodeGenFunction::ShouldNullCheckClassCastValue(const CastExpr *CE) {
   return true;
 }
 
+// RHS is an aggregate type
+static Value *EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address RHSVal,
+					QualType RHSTy, QualType LHSTy,
+					SourceLocation Loc) {
+  SmallVector<llvm::Value *, 4> IdxList;
+  SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
+  SmallVector<QualType> SrcTypes; // Flattened type
+  CGF.FlattenAccessAndType(RHSVal, RHSTy, IdxList, LoadGEPList, SrcTypes);
+  // LHS is either a vector or a builtin?
+  // if its a vector create a temp alloca to store into and return that
+  if (auto *VecTy = LHSTy->getAs<VectorType>()) {
+    llvm::Value *V = CGF.Builder.CreateLoad(CGF.CreateIRTemp(LHSTy, "flatcast.tmp"));
+    // write to V.
+    for(unsigned i = 0; i < VecTy->getNumElements(); i ++) {
+      llvm::Value *Load = CGF.PerformLoad(LoadGEPList[i]);
+      llvm::Value *Cast = CGF.EmitScalarConversion(Load, SrcTypes[i],
+						   VecTy->getElementType(), Loc);
+      V = CGF.Builder.CreateInsertElement(V, Cast, i);
+    }
+    return V;
+  }
+  // if it's a builtin just do an extract element or load.
+  assert(LHSTy->isBuiltinType() &&
+	 "Destination type must be a vector or builtin type.");
+  // TODO add asserts about things being long enough
+  return CGF.EmitScalarConversion(CGF.PerformLoad(LoadGEPList[0]),
+				  LHSTy, SrcTypes[0], Loc);
+}
+
 // VisitCastExpr - Emit code for an explicit or implicit cast.  Implicit casts
 // have to handle a more broad range of conversions than explicit casts, as they
 // handle things like function to ptr-to-function decay etc.
@@ -2752,7 +2781,17 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
     return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
   }
+  case CK_HLSLAggregateCast: {
+    RValue RV = CGF.EmitAnyExpr(E);
+    SourceLocation Loc = CE->getExprLoc();
+    QualType SrcTy = E->getType();
 
+    if (RV.isAggregate()) { // RHS is an aggregate
+      Address SrcVal = RV.getAggregateAddress();
+      return EmitHLSLAggregateFlatCast(CGF, SrcVal, SrcTy, DestTy, Loc);
+    }
+    llvm_unreachable("Not a valid HLSL Flat Cast.");
+  }
   } // end of switch
 
   llvm_unreachable("unknown scalar cast");
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index eaea0d8a08ac06..b17ead377610e6 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4359,6 +4359,13 @@ class CodeGenFunction : public CodeGenTypeCache {
                                 AggValueSlot slot = AggValueSlot::ignored());
   LValue EmitPseudoObjectLValue(const PseudoObjectExpr *e);
 
+  llvm::Value *PerformLoad(std::pair<Address, llvm::Value *> &GEP);
+  llvm::Value *PerformStore(std::pair<Address, llvm::Value *> &GEP, llvm::Value *Val);
+  void FlattenAccessAndType(Address Val, QualType SrcTy,
+			    SmallVector<llvm::Value *, 4> &IdxList,
+			    SmallVector<std::pair<Address, llvm::Value *>, 16> &GEPList,
+			    SmallVector<QualType> &FlatTypes);
+
   llvm::Value *EmitIvarOffset(const ObjCInterfaceDecl *Interface,
                               const ObjCIvarDecl *Ivar);
   llvm::Value *EmitIvarOffsetAsPointerDiff(const ObjCInterfaceDecl *Interface,
diff --git a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
index 81797c8c4dc75a..63308319a78d1c 100644
--- a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
+++ b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
@@ -1085,6 +1085,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
       llvm_unreachable("OpenCL-specific cast in Objective-C?");
 
     case CK_HLSLVectorTruncation:
+    case CK_HLSLAggregateCast:
       llvm_unreachable("HLSL-specific cast in Objective-C?");
       break;
 
diff --git a/clang/lib/Sema/Sema.cpp b/clang/lib/Sema/Sema.cpp
index d6517511d7db4d..2f0528d6ab5ce1 100644
--- a/clang/lib/Sema/Sema.cpp
+++ b/clang/lib/Sema/Sema.cpp
@@ -707,6 +707,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty,
     case CK_ToVoid:
     case CK_NonAtomicToAtomic:
     case CK_HLSLArrayRValue:
+    case CK_HLSLAggregateCast:
       break;
     }
   }
diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index f98857f852b5af..955c44cf4a6a42 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -25,6 +25,7 @@
 #include "clang/Sema/Initialization.h"
 #include "clang/Sema/SemaObjC.h"
 #include "clang/Sema/SemaRISCV.h"
+#include "clang/Sema/SemaHLSL.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include <set>
@@ -2768,6 +2769,22 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
     return;
   }
 
+  CheckedConversionKind CCK = FunctionalStyle
+                                  ? CheckedConversionKind::FunctionalCast
+                                  : CheckedConversionKind::CStyleCast;
+  // todo what else should i be doing lvalue to rvalue cast for?
+  // why dont they do it for records below?
+  // This case should not trigger on regular vector splat
+  // Or vector cast or vector truncation.
+  QualType SrcTy = SrcExpr.get()->getType();
+  if (Self.getLangOpts().HLSL &&
+      Self.HLSL().CanPerformAggregateCast(SrcExpr.get(), DestType)) {
+    if (SrcTy->isConstantArrayType())
+      SrcExpr = Self.ImpCastExprToType(SrcExpr.get(), Self.Context.getArrayParameterType(SrcTy), CK_HLSLArrayRValue, VK_PRValue, nullptr, CCK);
+    Kind = CK_HLSLAggregateCast;
+    return;
+  }
+
   if (ValueKind == VK_PRValue && !DestType->isRecordType() &&
       !isPlaceholder(BuiltinType::Overload)) {
     SrcExpr = Self.DefaultFunctionArrayLvalueConversion(SrcExpr.get());
@@ -2820,9 +2837,6 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
   if (isValidCast(tcr))
     Kind = CK_NoOp;
 
-  CheckedConversionKind CCK = FunctionalStyle
-                                  ? CheckedConversionKind::FunctionalCast
-                                  : CheckedConversionKind::CStyleCast;
   if (tcr == TC_NotApplicable) {
     tcr = TryAddressSpaceCast(Self, SrcExpr, DestType, /*CStyle*/ true, msg,
                               Kind);
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 88db3e12541193..942c0a8fcaab09 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2412,6 +2412,149 @@ bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
   return HadError;
 }
 
+// Follows PerformScalarCast
+bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
+
+  if (SemaRef.getASTContext().hasSameUnqualifiedType(SrcTy, DestTy))
+    return true;
+
+  switch (Type::ScalarTypeKind SrcKind = SrcTy->getScalarTypeKind()) {
+  case Type::STK_MemberPointer:
+    return false;
+
+  case Type::STK_CPointer:
+  case Type::STK_BlockPointer:
+  case Type::STK_ObjCObjectPointer:
+    switch (DestTy->getScalarTypeKind()) {
+    case Type::STK_CPointer:
+    case Type::STK_BlockPointer:
+    case Type::STK_ObjCObjectPointer:
+    case Type::STK_Bool:
+    case Type::STK_Integral:
+      return true;
+    case Type::STK_Floating:
+    case Type::STK_FloatingComplex:
+    case Type::STK_IntegralComplex:
+    case Type::STK_MemberPointer:
+      return false;
+    case Type::STK_FixedPoint:
+      llvm_unreachable("HLSL doesn't have fixed point types.");
+    }
+    llvm_unreachable("Should have returned before this");
+
+  case Type::STK_FixedPoint:
+    llvm_unreachable("HLSL doesn't have fixed point types.");
+
+  case Type::STK_Bool: // casting from bool is like casting from an integer
+  case Type::STK_Integral:
+    switch (DestTy->getScalarTypeKind()) {
+    case Type::STK_CPointer:
+    case Type::STK_ObjCObjectPointer:
+    case Type::STK_BlockPointer:
+    case Type::STK_Bool:
+    case Type::STK_Integral:
+    case Type::STK_Floating:
+    case Type::STK_IntegralComplex:
+    case Type::STK_FloatingComplex:
+      return true;
+    case Type::STK_FixedPoint:
+      llvm_unreachable("HLSL doesn't have fixed point types.");
+    case Type::STK_MemberPointer:
+      return false;
+    }
+    llvm_unreachable("Should have returned before this");
+
+  case Type::STK_Floating:
+    switch (DestTy->getScalarTypeKind()) {
+    case Type::STK_Floating:
+    case Type::STK_Bool:
+    case Type::STK_Integral:
+    case Type::STK_FloatingComplex:
+    case Type::STK_IntegralComplex:
+      return true;
+    case Type::STK_FixedPoint:
+      llvm_unreachable("HLSL doesn't have fixed point types.");
+    case Type::STK_CPointer:
+    case Type::STK_ObjCObjectPointer:
+    case Type::STK_BlockPointer:
+    case Type::STK_MemberPointer:
+      return false;
+    }
+    llvm_unreachable("Should have returned before this");
+
+  case Type::STK_FloatingComplex:
+    switch (DestTy->getScalarTypeKind()) {
+    case Type::STK_FloatingComplex:
+    case Type::STK_IntegralComplex:
+    case Type::STK_Floating:
+    case Type::STK_Bool:
+    case Type::STK_Integral:
+      return true;
+    case Type::STK_CPointer:
+    case Type::STK_ObjCObjectPointer:
+    case Type::STK_BlockPointer:
+    case Type::STK_MemberPointer:
+      return false;
+    case Type::STK_FixedPoint:
+      llvm_unreachable("HLSL doesn't have fixed point types.");
+    }
+    llvm_unreachable("Should have returned before this");
+
+  case Type::STK_IntegralComplex:
+    switch (DestTy->getScalarTypeKind()) {
+    case Type::STK_FloatingComplex:
+    case Type::STK_IntegralComplex:
+    case Type::STK_Integral:
+    case Type::STK_Bool:
+    case Type::STK_Floating:
+      return true;
+    case Type::STK_CPointer:
+    case Type::STK_ObjCObjectPointer:
+    case Type::STK_BlockPointer:
+    case Type::STK_MemberPointer:
+      return false;
+    case Type::STK_FixedPoint:
+      llvm_unreachable("HLSL doesn't have fixed point types.");
+    }
+    llvm_unreachable("Should have returned before this");
+  }
+
+  llvm_unreachable("Unhandled scalar cast");
+}
+
+// Can we perform an HLSL Flattened cast?
+bool SemaHLSL::CanPerformAggregateCast(Expr *Src, QualType DestTy) {
+
+  // Don't handle casts where LHS and RHS are any combination of scalar/vector
+  // There must be an aggregate somewhere
+  QualType SrcTy = Src->getType();
+  if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
+    return false;
+  
+  if ((DestTy->isScalarType() || DestTy->isVectorType()) &&
+      (SrcTy->isScalarType() || SrcTy->isVectorType()))
+    return false;
+
+  llvm::SmallVector<QualType> DestTypes;
+  BuildFlattenedTypeList(DestTy, DestTypes);
+  llvm::SmallVector<QualType> SrcTypes;
+  BuildFlattenedTypeList(SrcTy, SrcTypes);
+
+  // The size of SrcTypes must be greater than or equal to the size of DestTypes.
+  if (SrcTypes.size() >= DestTypes.size()) {
+
+    unsigned i;
+    for(i = 0; i < DestTypes.size() && i < SrcTypes.size(); i ++) {
+      if (!CanPerformScalarCast(SrcTypes[i], DestTypes[i])) {
+        return false;
+      }
+    }
+    return true;
+  } else { // can't cast, Src is wrong size for Dest
+    return false;
+  }
+}
+
 ExprResult SemaHLSL::ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg) {
   assert(Param->hasAttr<HLSLParamModifierAttr>() &&
          "We should not get here without a parameter modifier expression");
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
index 7a900780384a91..067ff064861ce7 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
@@ -522,6 +522,7 @@ void ExprEngine::VisitCast(const CastExpr *CastE, const Expr *Ex,
       case CK_ToUnion:
       case CK_MatrixCast:
       case CK_VectorSplat:
+    case CK_HLSLAggregateCast:
       case CK_HLSLVectorTruncation: {
         QualType resultType = CastE->getType();
         if (CastE->isGLValue())

>From 121f2a9ac38f8a8098db51f3fd3ccdc6e3fa6f7b Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 5 Dec 2024 17:41:51 +0000
Subject: [PATCH 2/8] fix broken test

---
 clang/test/SemaHLSL/BuiltIns/vector-constructors-erros.hlsl | 2 --
 1 file changed, 2 deletions(-)

diff --git a/clang/test/SemaHLSL/BuiltIns/vector-constructors-erros.hlsl b/clang/test/SemaHLSL/BuiltIns/vector-constructors-erros.hlsl
index 7f6bdc7e67836b..b004acdc7c502c 100644
--- a/clang/test/SemaHLSL/BuiltIns/vector-constructors-erros.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/vector-constructors-erros.hlsl
@@ -17,6 +17,4 @@ void entry() {
   // These _should_ work in HLSL but aren't yet supported.
   S s;
   float2 GettingStrange = float2(s, s); // expected-error{{no viable conversion from 'S' to 'float'}} expected-error{{no viable conversion from 'S' to 'float'}}
-  S2 s2;
-  float2 EvenStranger = float2(s2); // expected-error{{cannot convert 'S2' to 'float2' (vector of 2 'float' values) without a conversion operator}}
 }

>From 9cc06ce79bbae61309ff0ab060e570d129fb0be8 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 5 Dec 2024 17:44:38 +0000
Subject: [PATCH 3/8] make clang format happy

---
 clang/lib/CodeGen/CGExpr.cpp                  | 36 +++++++-------
 clang/lib/CodeGen/CGExprAgg.cpp               | 48 +++++++++----------
 clang/lib/CodeGen/CGExprScalar.cpp            | 19 ++++----
 clang/lib/CodeGen/CodeGenFunction.h           | 11 +++--
 clang/lib/Sema/SemaCast.cpp                   |  6 ++-
 clang/lib/Sema/SemaHLSL.cpp                   |  7 +--
 clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp |  2 +-
 7 files changed, 66 insertions(+), 63 deletions(-)

diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index b7608b1226758d..6b9c437ef7e242 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -6360,7 +6360,8 @@ LValue CodeGenFunction::EmitPseudoObjectLValue(const PseudoObjectExpr *E) {
   return emitPseudoObjectExpr(*this, E, true, AggValueSlot::ignored()).LV;
 }
 
-llvm::Value* CodeGenFunction::PerformLoad(std::pair<Address, llvm::Value *> &GEP) {
+llvm::Value *
+CodeGenFunction::PerformLoad(std::pair<Address, llvm::Value *> &GEP) {
   Address GEPAddress = GEP.first;
   llvm::Value *Idx = GEP.second;
   llvm::Value *V = Builder.CreateLoad(GEPAddress, "load");
@@ -6370,8 +6371,9 @@ llvm::Value* CodeGenFunction::PerformLoad(std::pair<Address, llvm::Value *> &GEP
   return V;
 }
 
-llvm::Value* CodeGenFunction::PerformStore(std::pair<Address, llvm::Value *> &GEP,
-				           llvm::Value *Val) {
+llvm::Value *
+CodeGenFunction::PerformStore(std::pair<Address, llvm::Value *> &GEP,
+                              llvm::Value *Val) {
   Address GEPAddress = GEP.first;
   llvm::Value *Idx = GEP.second;
   if (Idx) {
@@ -6382,20 +6384,21 @@ llvm::Value* CodeGenFunction::PerformStore(std::pair<Address, llvm::Value *> &GE
   }
 }
 
-void CodeGenFunction::FlattenAccessAndType(Address Val, QualType SrcTy,
-			         SmallVector<llvm::Value *, 4> &IdxList,
-			         SmallVector<std::pair<Address, llvm::Value *>, 16> &GEPList,
-				 SmallVector<QualType> &FlatTypes) {
-  llvm::IntegerType *IdxTy = llvm::IntegerType::get(getLLVMContext(),32);
+void CodeGenFunction::FlattenAccessAndType(
+    Address Val, QualType SrcTy, SmallVector<llvm::Value *, 4> &IdxList,
+    SmallVector<std::pair<Address, llvm::Value *>, 16> &GEPList,
+    SmallVector<QualType> &FlatTypes) {
+  llvm::IntegerType *IdxTy = llvm::IntegerType::get(getLLVMContext(), 32);
   if (const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(SrcTy)) {
     uint64_t Size = CAT->getZExtSize();
-    for(unsigned i = 0; i < Size; i ++) {
+    for (unsigned i = 0; i < Size; i++) {
       // flatten each member of the array
       // add index of this element to index list
       llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, i);
       IdxList.push_back(Idx);
       // recur on this object
-      FlattenAccessAndType(Val, CAT->getElementType(), IdxList, GEPList, FlatTypes);
+      FlattenAccessAndType(Val, CAT->getElementType(), IdxList, GEPList,
+                           FlatTypes);
       // remove index of this element from index list
       IdxList.pop_back();
     }
@@ -6405,7 +6408,7 @@ void CodeGenFunction::FlattenAccessAndType(Address Val, QualType SrcTy,
     // do I need to check if its a cxx record decl?
 
     for (auto fieldIter = Record->field_begin(), fieldEnd = Record->field_end();
-	 fieldIter != fieldEnd; ++fieldIter) {
+         fieldIter != fieldEnd; ++fieldIter) {
       // get the field number
       unsigned FieldNum = RL.getLLVMFieldNo(*fieldIter);
       // can we just do *fieldIter->getFieldIndex();
@@ -6414,16 +6417,16 @@ void CodeGenFunction::FlattenAccessAndType(Address Val, QualType SrcTy,
       IdxList.push_back(Idx);
       // recur on the field
       FlattenAccessAndType(Val, fieldIter->getType(), IdxList, GEPList,
-			   FlatTypes);
+                           FlatTypes);
       // remove index of this element from index list
       IdxList.pop_back();
     }
   } else if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
     llvm::Type *VTy = ConvertTypeForMem(SrcTy);
     CharUnits Align = getContext().getTypeAlignInChars(SrcTy);
-    Address GEP = Builder.CreateInBoundsGEP(Val, IdxList,
-						 VTy, Align, "vector.gep");
-    for(unsigned i = 0; i < VT->getNumElements(); i ++) {
+    Address GEP =
+        Builder.CreateInBoundsGEP(Val, IdxList, VTy, Align, "vector.gep");
+    for (unsigned i = 0; i < VT->getNumElements(); i++) {
       // add index to the list
       llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, i);
       // create gep. no need to recur since its always a scalar
@@ -6435,8 +6438,7 @@ void CodeGenFunction::FlattenAccessAndType(Address Val, QualType SrcTy,
     // create a gep
     llvm::Type *Ty = ConvertTypeForMem(SrcTy);
     CharUnits Align = getContext().getTypeAlignInChars(SrcTy);
-    Address GEP = Builder.CreateInBoundsGEP(Val, IdxList,
-						     Ty, Align,  "gep");
+    Address GEP = Builder.CreateInBoundsGEP(Val, IdxList, Ty, Align, "gep");
     GEPList.push_back({GEP, NULL});
     FlatTypes.push_back(SrcTy);
   }
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index bc8e1f0f9248ef..e3b47de958ce55 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -491,50 +491,45 @@ static bool isTrivialFiller(Expr *E) {
   return false;
 }
 
-
-
 // emit a flat cast where the RHS is a scalar, including vector
 static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
-			    QualType DestTy, llvm::Value *SrcVal,
-			    QualType SrcTy, SourceLocation Loc) {
+                                   QualType DestTy, llvm::Value *SrcVal,
+                                   QualType SrcTy, SourceLocation Loc) {
   // Flatten our destination
   SmallVector<QualType> DestTypes; // Flattened type
   SmallVector<llvm::Value *, 4> IdxList;
   SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
   // ^^ Flattened accesses to DestVal we want to store into
-  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList,
-		       DestTypes);
+  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes);
 
   if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
     SrcTy = VT->getElementType();
     assert(StoreGEPList.size() <= VT->getNumElements() &&
-	   "Cannot perform HLSL flat cast when vector source \
+           "Cannot perform HLSL flat cast when vector source \
            object has less elements than flattened destination \
            object.");
-      for(unsigned i = 0; i < StoreGEPList.size(); i ++) {
-        llvm::Value *Load = CGF.Builder.CreateExtractElement(SrcVal, i,
-							     "vec.load");
-	llvm::Value *Cast = CGF.EmitScalarConversion(Load, SrcTy,
-						     DestTypes[i],
-						     Loc);
-	CGF.PerformStore(StoreGEPList[i], Cast);
-      }
-      return;
+    for (unsigned i = 0; i < StoreGEPList.size(); i++) {
+      llvm::Value *Load =
+          CGF.Builder.CreateExtractElement(SrcVal, i, "vec.load");
+      llvm::Value *Cast =
+          CGF.EmitScalarConversion(Load, SrcTy, DestTypes[i], Loc);
+      CGF.PerformStore(StoreGEPList[i], Cast);
+    }
+    return;
   }
   llvm_unreachable("HLSL Flat cast doesn't handle splatting.");
 }
 
 // emit a flat cast where the RHS is an aggregate
 static void EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address DestVal,
-			       QualType DestTy, Address SrcVal,
-			       QualType SrcTy, SourceLocation Loc) {
+                                      QualType DestTy, Address SrcVal,
+                                      QualType SrcTy, SourceLocation Loc) {
   // Flatten our destination
   SmallVector<QualType> DestTypes; // Flattened type
   SmallVector<llvm::Value *, 4> IdxList;
   SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
   // ^^ Flattened accesses to DestVal we want to store into
-  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList,
-		       DestTypes);
+  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes);
   // Flatten our src
   SmallVector<QualType> SrcTypes; // Flattened type
   SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
@@ -543,14 +538,14 @@ static void EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address DestVal,
   CGF.FlattenAccessAndType(SrcVal, SrcTy, IdxList, LoadGEPList, SrcTypes);
 
   assert(StoreGEPList.size() <= LoadGEPList.size() &&
-	 "Cannot perform HLSL flat cast when flattened source object \
+         "Cannot perform HLSL flat cast when flattened source object \
           has less elements than flattened destination object.");
   // apply casts to what we load from LoadGEPList
   // and store result in Dest
-  for(unsigned i = 0; i < StoreGEPList.size(); i ++) {
+  for (unsigned i = 0; i < StoreGEPList.size(); i++) {
     llvm::Value *Load = CGF.PerformLoad(LoadGEPList[i]);
-    llvm::Value *Cast = CGF.EmitScalarConversion(Load, SrcTypes[i],
-						 DestTypes[i], Loc);
+    llvm::Value *Cast =
+        CGF.EmitScalarConversion(Load, SrcTypes[i], DestTypes[i], Loc);
     CGF.PerformStore(StoreGEPList[i], Cast);
   }
 }
@@ -967,11 +962,12 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
       EmitHLSLScalarFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
     } else { // RHS is an aggregate
       assert(RV.isAggregate() &&
-	     "Can't perform HLSL Aggregate cast on a complex type.");
+             "Can't perform HLSL Aggregate cast on a complex type.");
       Address SrcVal = RV.getAggregateAddress();
       EmitHLSLAggregateFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
     }
-    break; }
+    break;
+  }
   case CK_NoOp:
   case CK_UserDefinedConversion:
   case CK_ConstructorConversion:
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index d7bb702ec3ca20..3809e3b1db3494 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2264,8 +2264,8 @@ bool CodeGenFunction::ShouldNullCheckClassCastValue(const CastExpr *CE) {
 
 // RHS is an aggregate type
 static Value *EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address RHSVal,
-					QualType RHSTy, QualType LHSTy,
-					SourceLocation Loc) {
+                                        QualType RHSTy, QualType LHSTy,
+                                        SourceLocation Loc) {
   SmallVector<llvm::Value *, 4> IdxList;
   SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
   SmallVector<QualType> SrcTypes; // Flattened type
@@ -2273,22 +2273,23 @@ static Value *EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address RHSVal,
   // LHS is either a vector or a builtin?
   // if its a vector create a temp alloca to store into and return that
   if (auto *VecTy = LHSTy->getAs<VectorType>()) {
-    llvm::Value *V = CGF.Builder.CreateLoad(CGF.CreateIRTemp(LHSTy, "flatcast.tmp"));
+    llvm::Value *V =
+        CGF.Builder.CreateLoad(CGF.CreateIRTemp(LHSTy, "flatcast.tmp"));
     // write to V.
-    for(unsigned i = 0; i < VecTy->getNumElements(); i ++) {
+    for (unsigned i = 0; i < VecTy->getNumElements(); i++) {
       llvm::Value *Load = CGF.PerformLoad(LoadGEPList[i]);
-      llvm::Value *Cast = CGF.EmitScalarConversion(Load, SrcTypes[i],
-						   VecTy->getElementType(), Loc);
+      llvm::Value *Cast = CGF.EmitScalarConversion(
+          Load, SrcTypes[i], VecTy->getElementType(), Loc);
       V = CGF.Builder.CreateInsertElement(V, Cast, i);
     }
     return V;
   }
   // i its a builtin just do an extract element or load.
   assert(LHSTy->isBuiltinType() &&
-	 "Destination type must be a vector or builtin type.");
+         "Destination type must be a vector or builtin type.");
   // TODO add asserts about things being long enough
-  return CGF.EmitScalarConversion(CGF.PerformLoad(LoadGEPList[0]),
-				  LHSTy, SrcTypes[0], Loc);
+  return CGF.EmitScalarConversion(CGF.PerformLoad(LoadGEPList[0]), LHSTy,
+                                  SrcTypes[0], Loc);
 }
 
 // VisitCastExpr - Emit code for an explicit or implicit cast.  Implicit casts
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index b17ead377610e6..873dd781eb2e7d 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4360,11 +4360,12 @@ class CodeGenFunction : public CodeGenTypeCache {
   LValue EmitPseudoObjectLValue(const PseudoObjectExpr *e);
 
   llvm::Value *PerformLoad(std::pair<Address, llvm::Value *> &GEP);
-  llvm::Value *PerformStore(std::pair<Address, llvm::Value *> &GEP, llvm::Value *Val);
-  void FlattenAccessAndType(Address Val, QualType SrcTy,
-			    SmallVector<llvm::Value *, 4> &IdxList,
-			    SmallVector<std::pair<Address, llvm::Value *>, 16> &GEPList,
-			    SmallVector<QualType> &FlatTypes);
+  llvm::Value *PerformStore(std::pair<Address, llvm::Value *> &GEP,
+                            llvm::Value *Val);
+  void FlattenAccessAndType(
+      Address Val, QualType SrcTy, SmallVector<llvm::Value *, 4> &IdxList,
+      SmallVector<std::pair<Address, llvm::Value *>, 16> &GEPList,
+      SmallVector<QualType> &FlatTypes);
 
   llvm::Value *EmitIvarOffset(const ObjCInterfaceDecl *Interface,
                               const ObjCIvarDecl *Ivar);
diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index 955c44cf4a6a42..0bd7fc91aee18f 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -23,9 +23,9 @@
 #include "clang/Basic/TargetInfo.h"
 #include "clang/Lex/Preprocessor.h"
 #include "clang/Sema/Initialization.h"
+#include "clang/Sema/SemaHLSL.h"
 #include "clang/Sema/SemaObjC.h"
 #include "clang/Sema/SemaRISCV.h"
-#include "clang/Sema/SemaHLSL.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include <set>
@@ -2780,7 +2780,9 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
   if (Self.getLangOpts().HLSL &&
       Self.HLSL().CanPerformAggregateCast(SrcExpr.get(), DestType)) {
     if (SrcTy->isConstantArrayType())
-      SrcExpr = Self.ImpCastExprToType(SrcExpr.get(), Self.Context.getArrayParameterType(SrcTy), CK_HLSLArrayRValue, VK_PRValue, nullptr, CCK);
+      SrcExpr = Self.ImpCastExprToType(
+          SrcExpr.get(), Self.Context.getArrayParameterType(SrcTy),
+          CK_HLSLArrayRValue, VK_PRValue, nullptr, CCK);
     Kind = CK_HLSLAggregateCast;
     return;
   }
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 942c0a8fcaab09..5c7af8056063ad 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2530,7 +2530,7 @@ bool SemaHLSL::CanPerformAggregateCast(Expr *Src, QualType DestTy) {
   QualType SrcTy = Src->getType();
   if (SrcTy->isScalarType()) // always a splat and this cast doesn't handle that
     return false;
-  
+
   if ((DestTy->isScalarType() || DestTy->isVectorType()) &&
       (SrcTy->isScalarType() || SrcTy->isVectorType()))
     return false;
@@ -2540,11 +2540,12 @@ bool SemaHLSL::CanPerformAggregateCast(Expr *Src, QualType DestTy) {
   llvm::SmallVector<QualType> SrcTypes;
   BuildFlattenedTypeList(SrcTy, SrcTypes);
 
-  // Usually the size of SrcTypes must be greater than or equal to the size of DestTypes.
+  // Usually the size of SrcTypes must be greater than or equal to the size of
+  // DestTypes.
   if (SrcTypes.size() >= DestTypes.size()) {
 
     unsigned i;
-    for(i = 0; i < DestTypes.size() && i < SrcTypes.size(); i ++) {
+    for (i = 0; i < DestTypes.size() && i < SrcTypes.size(); i++) {
       if (!CanPerformScalarCast(SrcTypes[i], DestTypes[i])) {
         return false;
       }
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
index 067ff064861ce7..b105c196fc3bfb 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
@@ -522,7 +522,7 @@ void ExprEngine::VisitCast(const CastExpr *CastE, const Expr *Ex,
       case CK_ToUnion:
       case CK_MatrixCast:
       case CK_VectorSplat:
-    case CK_HLSLAggregateCast:
+      case CK_HLSLAggregateCast:
       case CK_HLSLVectorTruncation: {
         QualType resultType = CastE->getType();
         if (CastE->isGLValue())

>From e3e51b6761f2e9af61bfa6ae87860e05484e93c0 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 5 Dec 2024 17:46:16 +0000
Subject: [PATCH 4/8] CodeGen tests

---
 .../BasicFeatures/ArrayFlatCast.hlsl          | 128 ++++++++++++++++++
 .../BasicFeatures/StructFlatCast.hlsl         | 124 +++++++++++++++++
 .../BasicFeatures/VectorFlatCast.hlsl         |  81 +++++++++++
 3 files changed, 333 insertions(+)
 create mode 100644 clang/test/CodeGenHLSL/BasicFeatures/ArrayFlatCast.hlsl
 create mode 100644 clang/test/CodeGenHLSL/BasicFeatures/StructFlatCast.hlsl
 create mode 100644 clang/test/CodeGenHLSL/BasicFeatures/VectorFlatCast.hlsl

diff --git a/clang/test/CodeGenHLSL/BasicFeatures/ArrayFlatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/ArrayFlatCast.hlsl
new file mode 100644
index 00000000000000..23a71a2ecc6b96
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BasicFeatures/ArrayFlatCast.hlsl
@@ -0,0 +1,128 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -disable-llvm-passes -emit-llvm -finclude-default-header -o - %s | FileCheck %s
+
+// array truncation
+// CHECK-LABEL: define void {{.*}}call1
+// CHECK: [[A:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: [[B:%.*]] = alloca [1 x i32], align 4
+// CHECK-NEXT: [[Tmp:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[A]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 4, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[A]], i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [1 x i32], ptr [[B]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[Tmp]], i32 0
+// CHECK-NEXT: [[G3:%.*]] = getelementptr inbounds [2 x i32], ptr [[Tmp]], i32 1
+// CHECK-NEXT: [[L:%.*]] = load i32, ptr [[G2]], align 4
+// CHECK-NEXT: store i32 [[L]], ptr [[G1]], align 4
+export void call1() {
+  int A[2] = {0,1};
+  int B[1] = {4};
+  B = (int[1])A;
+}
+
+// just a cast
+// CHECK-LABEL: define void {{.*}}call2
+// CHECK: [[A:%.*]] = alloca [1 x i32], align 4
+// CHECK-NEXT: [[B:%.*]] = alloca [1 x float], align 4
+// CHECK-NEXT: [[Tmp:%.*]] = alloca [1 x i32], align 4
+// CHECK-NEXT: call void @llvm.memset.p0.i32(ptr align 4 [[A]], i8 0, i32 4, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 4, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[A]], i32 4, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [1 x float], ptr [[B]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [1 x i32], ptr [[Tmp]], i32 0
+// CHECK-NEXT: [[L:%.*]] = load i32, ptr [[G2]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[L]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G1]], align 4
+export void call2() {
+  int A[1] = {0};
+  float B[1] = {1.0};
+  B = (float[1])A;
+}
+
+// vector to array
+// CHECK-LABEL: define void {{.*}}call3
+// CHECK: [[A:%.*]] = alloca <1 x float>, align 4
+// CHECK-NEXT: [[B:%.*]] = alloca [1 x i32], align 4
+// CHECK-NEXT: store <1 x float> splat (float 0x3FF3333340000000), ptr [[A]], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 4, i1 false)
+// CHECK-NEXT: [[C:%.*]] = load <1 x float>, ptr [[A]], align 4
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [1 x i32], ptr [[B]], i32 0
+// CHECK-NEXT: [[V:%.*]] = extractelement <1 x float> [[C]], i64 0
+// CHECK-NEXT: [[C:%.*]] = fptosi float [[V]] to i32
+// CHECK-NEXT: store i32 [[C]], ptr [[G1]], align 4
+export void call3() {
+  float1 A = {1.2};
+  int B[1] = {1};
+  B = (int[1])A;
+}
+
+// flatten array of vector to array with cast
+// CHECK-LABEL: define void {{.*}}call5
+// CHECK: [[A:%.*]] = alloca [1 x <2 x float>], align 8
+// CHECK-NEXT: [[B:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: [[Tmp:%.*]] = alloca [1 x <2 x float>], align 8
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 8 [[A]], ptr align 8 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 8 [[Tmp]], ptr align 8 [[A]], i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: [[VG:%.*]] = getelementptr inbounds [1 x <2 x float>], ptr [[Tmp]], i32 0
+// CHECK-NEXT: [[L:%.*]] = load <2 x float>, ptr [[VG]], align 8
+// CHECK-NEXT: [[VL:%.*]] = extractelement <2 x float> [[L]], i32 0
+// CHECK-NEXT: [[C:%.*]] = fptosi float [[VL]] to i32
+// CHECK-NEXT: store i32 [[C]], ptr [[G1]], align 4
+// CHECK-NEXT: [[L4:%.*]] = load <2 x float>, ptr [[VG]], align 8
+// CHECK-NEXT: [[VL5:%.*]] = extractelement <2 x float> [[L4]], i32 1
+// CHECK-NEXT: [[C6:%.*]] = fptosi float [[VL5]] to i32
+// CHECK-NEXT: store i32 [[C6]], ptr [[G2]], align 4
+export void call5() {
+  float2 A[1] = {{1.2,3.4}};
+  int B[2] = {1,2};
+  B = (int[2])A;
+}
+
+// flatten 2d array
+// CHECK-LABEL: define void {{.*}}call6
+// CHECK: [[A:%.*]] = alloca [2 x [1 x i32]], align 4
+// CHECK-NEXT: [[B:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: [[Tmp:%.*]] = alloca [2 x [1 x i32]], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[A]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[A]], i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: [[G3:%.*]] = getelementptr inbounds [2 x [1 x i32]], ptr [[Tmp]], i32 0, i32 0
+// CHECK-NEXT: [[G4:%.*]] = getelementptr inbounds [2 x [1 x i32]], ptr [[Tmp]], i32 1, i32 0
+// CHECK-NEXT: [[L:%.*]] = load i32, ptr [[G3]], align 4
+// CHECK-NEXT: store i32 [[L]], ptr [[G1]], align 4
+// CHECK-NEXT: [[L4:%.*]] = load i32, ptr [[G4]], align 4
+// CHECK-NEXT: store i32 [[L4]], ptr [[G2]], align 4
+export void call6() {
+  int A[2][1] = {{1},{3}};
+  int B[2] = {1,2};
+  B = (int[2])A;
+}
+
+struct S {
+  int X;
+  float Y;
+};
+
+// flatten and truncate from a struct
+// CHECK-LABEL: define void {{.*}}call7
+// CHECK: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: [[A:%.*]] = alloca [1 x i32], align 4
+// CHECK-NEXT: [[Tmp:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[s]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[A]], ptr align 4 {{.*}}, i32 4, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[s]], i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [1 x i32], ptr [[A]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[Tmp]], i32 0
+// CHECK-NEXT: [[G3:%.*]] = getelementptr inbounds %struct.S, ptr [[Tmp]], i32 1
+// CHECK-NEXT: [[L:%.*]] = load i32, ptr [[G2]], align 4
+// CHECK-NEXT: store i32 [[L]], ptr [[G1]], align 4
+export void call7() {
+  S s = {1, 2.9};
+  int A[1] = {1};
+  A = (int[1])s;
+}
+
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/StructFlatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/StructFlatCast.hlsl
new file mode 100644
index 00000000000000..c44e340109abb2
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BasicFeatures/StructFlatCast.hlsl
@@ -0,0 +1,124 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
+
+struct S {
+  int X;
+  float Y;
+};
+
+// struct from vector
+// CHECK-LABEL: define void {{.*}}call1
+// CHECK: [[A:%.*]] = alloca <2 x i32>, align 8
+// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: store <2 x i32> <i32 1, i32 2>, ptr [[A]], align 8
+// CHECK-NEXT: [[L:%.*]] = load <2 x i32>, ptr [[A]], align 8
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[VL:%.*]] = extractelement <2 x i32> [[L]], i64 0
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: [[VL2:%.*]] = extractelement <2 x i32> [[L]], i64 1
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL2]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
+export void call1() {
+  int2 A = {1,2};
+  S s = (S)A;
+}
+
+
+// struct from array
+// CHECK-LABEL: define void {{.*}}call2
+// CHECK: [[A:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: [[Tmp:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[A]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[A]], i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[G3:%.*]] = getelementptr inbounds [2 x i32], ptr [[Tmp]], i32 0
+// CHECK-NEXT: [[G4:%.*]] = getelementptr inbounds [2 x i32], ptr [[Tmp]], i32 1
+// CHECK-NEXT: [[L:%.*]] = load i32, ptr [[G3]], align 4
+// CHECK-NEXT: store i32 [[L]], ptr [[G1]], align 4
+// CHECK-NEXT: [[L4:%.*]] = load i32, ptr [[G4]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[L4]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
+export void call2() {
+  int A[2] = {1,2};
+  S s = (S)A;
+}
+
+struct Q {
+  int Z;
+};
+
+struct R {
+  Q q;
+  float F;
+};
+
+// struct from nested struct?
+// CHECK-LABEL: define void {{.*}}call6
+// CHECK: [[r:%.*]] = alloca %struct.R, align 4
+// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: [[Tmp:%.*]] = alloca %struct.R, align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[r]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[r]], i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[G3:%.*]] = getelementptr inbounds %struct.R, ptr [[Tmp]], i32 0, i32 0
+// CHECK-NEXT: [[G4:%.*]] = getelementptr inbounds %struct.R, ptr [[Tmp]], i32 1
+// CHECK-NEXT: [[L:%.*]] = load i32, ptr [[G3]], align 4
+// CHECK-NEXT: store i32 [[L]], ptr [[G1]], align 4
+// CHECK-NEXT: [[L4:%.*]] = load float, ptr [[G4]], align 4
+// CHECK-NEXT: store float [[L4]], ptr [[G2]], align 4
+export void call6() {
+  R r = {{1}, 2.0};
+  S s = (S)r;
+}
+
+// nested struct from array?
+// CHECK-LABEL: define void {{.*}}call7
+// CHECK: [[A:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: [[r:%.*]] = alloca %struct.R, align 4
+// CHECK-NEXT: [[Tmp:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[A]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[A]], i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.R, ptr [[r]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.R, ptr [[r]], i32 1
+// CHECK-NEXT: [[G3:%.*]] = getelementptr inbounds [2 x i32], ptr [[Tmp]], i32 0
+// CHECK-NEXT: [[G4:%.*]] = getelementptr inbounds [2 x i32], ptr [[Tmp]], i32 1
+// CHECK-NEXT: [[L:%.*]] = load i32, ptr [[G3]], align 4
+// CHECK-NEXT: store i32 [[L]], ptr [[G1]], align 4
+// CHECK-NEXT: [[L4:%.*]] = load i32, ptr [[G4]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[L4]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
+export void call7() {
+  int A[2] = {1,2};
+  R r = (R)A;
+}
+
+struct T {
+  int A;
+  int B;
+  int C;
+};
+
+// struct truncation
+// CHECK-LABEL: define void {{.*}}call8
+// CHECK: [[t:%.*]] = alloca %struct.T, align 4
+// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: [[Tmp:%.*]] = alloca %struct.T, align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[t]], ptr align 4 {{.*}}, i32 12, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[t]], i32 12, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[G3:%.*]] = getelementptr inbounds %struct.T, ptr [[Tmp]], i32 0
+// CHECK-NEXT: [[G4:%.*]] = getelementptr inbounds %struct.T, ptr [[Tmp]], i32 1
+// CHECK-NEXT: [[G5:%.*]] = getelementptr inbounds %struct.T, ptr [[Tmp]], i32 2
+// CHECK-NEXT: [[L:%.*]] = load i32, ptr [[G3]], align 4
+// CHECK-NEXT: store i32 [[L]], ptr [[G1]], align 4
+// CHECK-NEXT: [[L5:%.*]] = load i32, ptr [[G4]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[L5]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
+export void call8() {
+  T t = {1,2,3};
+  S s = (S)t;
+}
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/VectorFlatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/VectorFlatCast.hlsl
new file mode 100644
index 00000000000000..9cd320ee9f62db
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BasicFeatures/VectorFlatCast.hlsl
@@ -0,0 +1,81 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
+
+// vector flat cast from array
+// CHECK-LABEL: define void {{.*}}call2
+// CHECK: [[A:%.*]] = alloca [2 x [1 x i32]], align 4
+// CHECK-NEXT: [[B:%.*]] = alloca <2 x i32>, align 8
+// CHECK-NEXT: [[Tmp:%.*]] = alloca [2 x [1 x i32]], align 4
+// CHECK-NEXT: [[Tmp2:%.*]] = alloca <2 x i32>, align 8
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[A]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[A]], i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x [1 x i32]], ptr [[Tmp]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x [1 x i32]], ptr [[Tmp]], i32 1, i32 0
+// CHECK-NEXT: [[C:%.*]] = load <2 x i32>, ptr [[Tmp2]], align 8
+// CHECK-NEXT: [[L:%.*]] = load i32, ptr [[G1]], align 4
+// CHECK-NEXT: [[D:%.*]] = insertelement <2 x i32> [[C]], i32 [[L]], i64 0
+// CHECK-NEXT: [[L2:%.*]] = load i32, ptr [[G2]], align 4
+// CHECK-NEXT: [[E:%.*]] = insertelement <2 x i32> [[D]], i32 [[L2]], i64 1
+// CHECK-NEXT: store <2 x i32> [[E]], ptr [[B]], align 8
+export void call2() {
+  int A[2][1] = {{1},{2}};
+  int2 B = (int2)A;
+}
+
+struct S {
+  int X;
+  float Y;
+};
+
+// vector flat cast from struct
+// CHECK-LABEL: define void {{.*}}call3
+// CHECK: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: [[A:%.*]] = alloca <2 x i32>, align 8
+// CHECK-NEXT: [[Tmp:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: [[Tmp2:%.*]] = alloca <2 x i32>, align 8
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[s]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[s]], i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[Tmp]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[Tmp]], i32 1
+// CHECK-NEXT: [[B:%.*]] = load <2 x i32>, ptr [[Tmp2]], align 8
+// CHECK-NEXT: [[L:%.*]] = load i32, ptr [[G1]], align 4
+// CHECK-NEXT: [[C:%.*]] = insertelement <2 x i32> [[B]], i32 [[L]], i64 0
+// CHECK-NEXT: [[L2:%.*]] = load float, ptr [[G2]], align 4
+// CHECK-NEXT: [[D:%.*]] = fptosi float [[L2]] to i32
+// CHECK-NEXT: [[E:%.*]] = insertelement <2 x i32> [[C]], i32 [[D]], i64 1
+// CHECK-NEXT: store <2 x i32> [[E]], ptr [[A]], align 8
+export void call3() {
+  S s = {1, 2.0};
+  int2 A = (int2)s;
+}
+
+// truncate array to scalar
+// CHECK-LABEL: define void {{.*}}call4
+// CHECK: [[A:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: [[B:%.*]] = alloca i32, align 4
+// CHECK-NEXT: [[Tmp:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[A]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[A]], i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[Tmp]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[Tmp]], i32 1
+// CHECK-NEXT: [[L:%.*]] = load i32, ptr [[G1]], align 4
+// CHECK-NEXT: store i32 [[L]], ptr [[B]], align 4
+export void call4() {
+ int A[2] = {1,2};
+ int B = (int)A;
+}
+
+// truncate struct to scalar
+// CHECK-LABEL: define void {{.*}}call5
+// CHECK: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: [[A:%.*]] = alloca i32, align 4
+// CHECK-NEXT: [[Tmp:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[s]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[Tmp]], ptr align 4 [[s]], i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[Tmp]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[Tmp]], i32 1
+// CHECK-NEXT: [[L:%.*]] = load i32, ptr [[G1]], align 4
+// CHECK-NEXT: store i32 [[L]], ptr [[A]], align 4
+export void call5() {
+ S s = {1, 2.0};
+ int A = (int)s;
+}

>From 63d7f2a0f7dc6a579340798d6a5f2e1ab3c96282 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 05:14:17 +0000
Subject: [PATCH 5/8] splat cast wip

---
 clang/include/clang/AST/OperationKinds.def |  3 ++
 clang/include/clang/Sema/SemaHLSL.h        |  1 +
 clang/lib/CodeGen/CGExprAgg.cpp            | 42 ++++++++++++++++++++++
 clang/lib/CodeGen/CGExprScalar.cpp         | 16 +++++++++
 clang/lib/Sema/Sema.cpp                    |  1 +
 clang/lib/Sema/SemaCast.cpp                | 10 +++---
 clang/lib/Sema/SemaHLSL.cpp                | 26 ++++++++++++++
 7 files changed, 95 insertions(+), 4 deletions(-)

diff --git a/clang/include/clang/AST/OperationKinds.def b/clang/include/clang/AST/OperationKinds.def
index 9323d4e861a734..84e9e635b276c2 100644
--- a/clang/include/clang/AST/OperationKinds.def
+++ b/clang/include/clang/AST/OperationKinds.def
@@ -370,6 +370,9 @@ CAST_OPERATION(HLSLArrayRValue)
 // Aggregate by Value cast (HLSL only).
 CAST_OPERATION(HLSLAggregateCast)
 
+// Splat cast for Aggregates (HLSL only).
+CAST_OPERATION(HLSLSplatCast)
+
 //===- Binary Operations  -------------------------------------------------===//
 // Operators listed in order of precedence.
 // Note that additions to this should also update the StmtVisitor class,
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 6bda1e8ce0ea5b..1482ef7b4294d6 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -142,6 +142,7 @@ class SemaHLSL : public SemaBase {
 
   bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
   bool CanPerformAggregateCast(Expr *Src, QualType DestType);
+  bool CanPerformSplat(Expr *Src, QualType DestType);
   ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
 
   QualType getInoutParameterType(QualType Ty);
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index e3b47de958ce55..2642fe8328057d 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -491,6 +491,33 @@ static bool isTrivialFiller(Expr *E) {
   return false;
 }
 
+static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
+                              QualType DestTy, llvm::Value *SrcVal,
+                              QualType SrcTy, SourceLocation Loc) {
+  // Flatten our destination
+  SmallVector<QualType> DestTypes; // Flattened type
+  SmallVector<llvm::Value *, 4> IdxList;
+  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
+  // ^^ Flattened accesses to DestVal we want to store into
+  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList,
+                           DestTypes);
+
+  if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
+    assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast.");
+
+    SrcTy = VT->getElementType();
+    SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0,
+                                              "vec.load");
+  }
+  assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
+  for (unsigned i = 0; i < StoreGEPList.size(); i++) {
+    llvm::Value *Cast = CGF.EmitScalarConversion(SrcVal, SrcTy,
+                                                 DestTypes[i],
+                                                 Loc);
+    CGF.PerformStore(StoreGEPList[i], Cast);
+  }
+}
+
 // emit a flat cast where the RHS is a scalar, including vector
 static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
                                    QualType DestTy, llvm::Value *SrcVal,
@@ -949,6 +976,21 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
   case CK_HLSLArrayRValue:
     Visit(E->getSubExpr());
     break;
+  case CK_HLSLSplatCast: {
+    Expr *Src = E->getSubExpr();
+    QualType SrcTy = Src->getType();
+    RValue RV = CGF.EmitAnyExpr(Src);
+    QualType DestTy = E->getType();
+    Address DestVal = Dest.getAddress();
+    SourceLocation Loc = E->getExprLoc();
+    
+    if (RV.isScalar()) {
+      llvm::Value *SrcVal = RV.getScalarVal();
+      EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+      break;
+    }
+    llvm_unreachable("RHS of HLSL splat cast must be a scalar or vector.");
+  }
   case CK_HLSLAggregateCast: {
     Expr *Src = E->getSubExpr();
     QualType SrcTy = Src->getType();
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 3809e3b1db3494..eb01dfc786b134 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2782,6 +2782,22 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
     return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
   }
+  case CK_HLSLSplatCast: {
+    assert(DestTy->isVectorType() && "Destination type must be a vector.");
+    auto *DestVecTy = DestTy->getAs<VectorType>();
+    QualType SrcTy = E->getType();
+    SourceLocation Loc = CE->getExprLoc();
+    Value *V = Visit(const_cast<Expr *>(E));
+    if (auto *VecTy = SrcTy->getAs<VectorType>()) {
+      assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
+      V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
+      SrcTy = VecTy->getElementType();
+    }
+    assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
+    Value *Cast = EmitScalarConversion(V, SrcTy,
+				       DestVecTy->getElementType(), Loc);
+    return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, "splat");
+  }
   case CK_HLSLAggregateCast: {
     RValue RV = CGF.EmitAnyExpr(E);
     SourceLocation Loc = CE->getExprLoc();
diff --git a/clang/lib/Sema/Sema.cpp b/clang/lib/Sema/Sema.cpp
index 2f0528d6ab5ce1..7ba448b8ede5aa 100644
--- a/clang/lib/Sema/Sema.cpp
+++ b/clang/lib/Sema/Sema.cpp
@@ -708,6 +708,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty,
     case CK_NonAtomicToAtomic:
     case CK_HLSLArrayRValue:
     case CK_HLSLAggregateCast:
+    case CK_HLSLSplatCast:
       break;
     }
   }
diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index 0bd7fc91aee18f..4f8e9a1899ae80 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -2772,11 +2772,13 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
   CheckedConversionKind CCK = FunctionalStyle
                                   ? CheckedConversionKind::FunctionalCast
                                   : CheckedConversionKind::CStyleCast;
-  // todo what else should i be doing lvalue to rvalue cast for?
-  // why dont they do it for records below?
-  // This case should not trigger on regular vector splat
-  // Or vector cast or vector truncation.
   QualType SrcTy = SrcExpr.get()->getType();
+  if (Self.getLangOpts().HLSL &&
+      Self.HLSL().CanPerformSplat(SrcExpr.get(), DestType)) {
+    Kind = CK_HLSLSplatCast;
+    return;
+  }
+
   if (Self.getLangOpts().HLSL &&
       Self.HLSL().CanPerformAggregateCast(SrcExpr.get(), DestType)) {
     if (SrcTy->isConstantArrayType())
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 5c7af8056063ad..1680475017f78e 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2522,6 +2522,32 @@ bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
   llvm_unreachable("Unhandled scalar cast");
 }
 
+// An HLSL splat cast can be performed when the destination is an
+// aggregate and the source is a scalar or a one-element vector, or
+// when the destination is a vector and the source is a one-element vector.
+bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
+
+  QualType SrcTy = Src->getType();
+  if (SrcTy->isScalarType() && DestTy->isVectorType())
+    return false;
+
+  const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
+  if (!(SrcTy->isScalarType() || (SrcVecTy && SrcVecTy->getNumElements() == 1)))
+    return false;
+
+  if (SrcVecTy)
+    SrcTy = SrcVecTy->getElementType();
+
+  llvm::SmallVector<QualType> DestTypes;
+  BuildFlattenedTypeList(DestTy, DestTypes);
+
+  for(unsigned i = 0; i < DestTypes.size(); i ++) {
+    if (!CanPerformScalarCast(SrcTy, DestTypes[i]))
+      return false;
+  }
+  return true;
+}
+
 // Can we perform an HLSL Flattened cast?
 bool SemaHLSL::CanPerformAggregateCast(Expr *Src, QualType DestTy) {
 

>From 2f2da8b23c1b951fd871aa7ef2d9dfaad2c6d6d8 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 05:19:00 +0000
Subject: [PATCH 6/8] make clang format happy

---
 clang/lib/CodeGen/CGExprAgg.cpp    | 19 ++++++++-----------
 clang/lib/CodeGen/CGExprScalar.cpp |  7 ++++---
 clang/lib/Sema/SemaHLSL.cpp        |  2 +-
 3 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 2642fe8328057d..175ccede228a8d 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -492,28 +492,25 @@ static bool isTrivialFiller(Expr *E) {
 }
 
 static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
-			      QualType DestTy, llvm::Value *SrcVal,
-			      QualType SrcTy, SourceLocation Loc) {
+                              QualType DestTy, llvm::Value *SrcVal,
+                              QualType SrcTy, SourceLocation Loc) {
   // Flatten our destination
   SmallVector<QualType> DestTypes; // Flattened type
   SmallVector<llvm::Value *, 4> IdxList;
   SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
   // ^^ Flattened accesses to DestVal we want to store into
-  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList,
-		       DestTypes);
+  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes);
 
   if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
     assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast.");
 
     SrcTy = VT->getElementType();
-    SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0,
-					      "vec.load");
+    SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0, "vec.load");
   }
   assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
-  for(unsigned i = 0; i < StoreGEPList.size(); i ++) {
-    llvm::Value *Cast = CGF.EmitScalarConversion(SrcVal, SrcTy,
-						 DestTypes[i],
-						 Loc);
+  for (unsigned i = 0; i < StoreGEPList.size(); i++) {
+    llvm::Value *Cast =
+        CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[i], Loc);
     CGF.PerformStore(StoreGEPList[i], Cast);
   }
 }
@@ -983,7 +980,7 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
     QualType DestTy = E->getType();
     Address DestVal = Dest.getAddress();
     SourceLocation Loc = E->getExprLoc();
-    
+
     if (RV.isScalar()) {
       llvm::Value *SrcVal = RV.getScalarVal();
       EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index eb01dfc786b134..1b352bc2e8abe7 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2794,9 +2794,10 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
       SrcTy = VecTy->getElementType();
     }
     assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
-    Value *Cast = EmitScalarConversion(V, SrcTy,
-				       DestVecTy->getElementType(), Loc);
-    return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, "splat");
+    Value *Cast =
+        EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc);
+    return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast,
+                                     "splat");
   }
   case CK_HLSLAggregateCast: {
     RValue RV = CGF.EmitAnyExpr(E);
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 1680475017f78e..0484f53186219f 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2541,7 +2541,7 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
   llvm::SmallVector<QualType> DestTypes;
   BuildFlattenedTypeList(DestTy, DestTypes);
 
-  for(unsigned i = 0; i < DestTypes.size(); i ++) {
+  for (unsigned i = 0; i < DestTypes.size(); i++) {
     if (!CanPerformScalarCast(SrcTy, DestTypes[i]))
       return false;
   }

>From d6c200fd86d05a64e99536f0e1d067848d8b837c Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 05:59:12 +0000
Subject: [PATCH 7/8] codegen test

---
 .../CodeGenHLSL/BasicFeatures/SplatCast.hlsl  | 87 +++++++++++++++++++
 1 file changed, 87 insertions(+)
 create mode 100644 clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl

diff --git a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
new file mode 100644
index 00000000000000..05359c1bce0ba3
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
@@ -0,0 +1,87 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
+
+// array splat
+// CHECK-LABEL: define void {{.*}}call4
+// CHECK: [[B:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: store i32 3, ptr [[G1]], align 4
+// CHECK-NEXT: store i32 3, ptr [[G2]], align 4
+export void call4() {
+  int B[2] = {1,2};
+  B = (int[2])3;
+}
+
+// splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call8
+// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK-NEXT: [[B:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4
+export void call8() {
+  int1 A = {1};
+  int B[2] = {1,2};
+  B = (int[2])A;
+}
+
+// vector splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call1
+// CHECK: [[B:%.*]] = alloca <1 x float>, align 4
+// CHECK-NEXT: [[A:%.*]] = alloca <4 x i32>, align 16
+// CHECK-NEXT: store <1 x float> splat (float 1.000000e+00), ptr [[B]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x float>, ptr [[B]], align 4
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i64 0
+// CHECK-NEXT: [[C:%.*]] = fptosi float [[VL]] to i32
+// CHECK-NEXT: [[SI:%.*]] = insertelement <4 x i32> poison, i32 [[C]], i64 0
+// CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT: store <4 x i32> [[S]], ptr [[A]], align 16
+export void call1() {
+  float1 B = {1.0};
+  int4 A = (int4)B;
+}
+
+struct S {
+  int X;
+  float Y;
+};
+
+// struct splat from a one-element vector
+// CHECK-LABEL: define void {{.*}}call3
+// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
+export void call3() {
+  int1 A = {1};
+  S s = (S)A;
+}
+
+// struct splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call5
+// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
+export void call5() {
+  int1 A = {1};
+  S s = (S)A;
+}

>From a74daceacc385f6bf4c876ff1560a544d9cb16ef Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 17:38:58 +0000
Subject: [PATCH 8/8] Try to handle Cast in all the places it needs to be
 handled

---
 clang/lib/AST/Expr.cpp                        | 1 +
 clang/lib/AST/ExprConstant.cpp                | 2 ++
 clang/lib/CodeGen/CGExpr.cpp                  | 1 +
 clang/lib/CodeGen/CGExprAgg.cpp               | 1 +
 clang/lib/CodeGen/CGExprComplex.cpp           | 1 +
 clang/lib/CodeGen/CGExprConstant.cpp          | 1 +
 clang/lib/Edit/RewriteObjCFoundationAPI.cpp   | 1 +
 clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp | 1 +
 8 files changed, 9 insertions(+)

diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index 4764bc84ce498a..e5e3b073ee08a8 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -1943,6 +1943,7 @@ bool CastExpr::CastConsistency() const {
   case CK_HLSLArrayRValue:
   case CK_HLSLVectorTruncation:
   case CK_HLSLAggregateCast:
+  case CK_HLSLSplatCast:
   CheckNoBasePath:
     assert(path_empty() && "Cast kind should not have a base path!");
     break;
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index b548cef41b7525..45b4585ffa3867 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -14857,6 +14857,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_FixedPointCast:
   case CK_IntegralToFixedPoint:
   case CK_MatrixCast:
+  case CK_HLSLSplatCast:
     llvm_unreachable("invalid cast kind for integral value");
 
   case CK_BitCast:
@@ -15734,6 +15735,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_MatrixCast:
   case CK_HLSLVectorTruncation:
   case CK_HLSLAggregateCast:
+  case CK_HLSLSplatCast:
     llvm_unreachable("invalid cast kind for complex value");
 
   case CK_LValueToRValue:
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 6b9c437ef7e242..462a6e0fbc23a2 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -5321,6 +5321,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
   case CK_HLSLVectorTruncation:
   case CK_HLSLArrayRValue:
   case CK_HLSLAggregateCast:
+    // TODO: determine whether CK_HLSLSplatCast can ever appear in an
+    // lvalue cast; for now it falls through to the unsupported path.
     return EmitUnsupportedLValue(E, "unexpected cast lvalue");
 
   case CK_Dependent:
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 175ccede228a8d..30f889c78f4016 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -1577,6 +1577,7 @@ static bool castPreservesZero(const CastExpr *CE) {
   case CK_NonAtomicToAtomic:
   case CK_AtomicToNonAtomic:
   case CK_HLSLVectorTruncation:
+    // TODO: decide whether CK_HLSLSplatCast also preserves zero and,
+    // if so, add it to this list.
     return true;
 
   case CK_BaseToDerivedMemberPointer:
diff --git a/clang/lib/CodeGen/CGExprComplex.cpp b/clang/lib/CodeGen/CGExprComplex.cpp
index 05680d36aa2bd7..91e06f9d0ea6e2 100644
--- a/clang/lib/CodeGen/CGExprComplex.cpp
+++ b/clang/lib/CodeGen/CGExprComplex.cpp
@@ -611,6 +611,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
   case CK_HLSLVectorTruncation:
   case CK_HLSLArrayRValue:
   case CK_HLSLAggregateCast:
+  case CK_HLSLSplatCast:
     llvm_unreachable("invalid cast kind for complex value");
 
   case CK_FloatingRealToComplex:
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index 6d15bc9058e450..c66691c3f98261 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -1336,6 +1336,7 @@ class ConstExprEmitter
     case CK_HLSLVectorTruncation:
     case CK_HLSLArrayRValue:
     case CK_HLSLAggregateCast:
+    case CK_HLSLSplatCast:
       return nullptr;
     }
     llvm_unreachable("Invalid CastKind");
diff --git a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
index 63308319a78d1c..59b9c8d7c8da4a 100644
--- a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
+++ b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
@@ -1086,6 +1086,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
 
     case CK_HLSLVectorTruncation:
     case CK_HLSLAggregateCast:
+    case CK_HLSLSplatCast:
       llvm_unreachable("HLSL-specific cast in Objective-C?");
       break;
 
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
index b105c196fc3bfb..d8780c52221508 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
@@ -523,6 +523,7 @@ void ExprEngine::VisitCast(const CastExpr *CastE, const Expr *Ex,
       case CK_MatrixCast:
       case CK_VectorSplat:
       case CK_HLSLAggregateCast:
+      case CK_HLSLSplatCast:
       case CK_HLSLVectorTruncation: {
         QualType resultType = CastE->getType();
         if (CastE->isGLValue())



More information about the cfe-commits mailing list