[clang] [HLSL] Implement HLSL splatting (PR #118992)

Sarah Spall via cfe-commits cfe-commits at lists.llvm.org
Thu Feb 13 15:25:40 PST 2025


https://github.com/spall updated https://github.com/llvm/llvm-project/pull/118992

>From e994824f3630ee8b224afceb6c14d980c9013112 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 05:14:17 +0000
Subject: [PATCH 01/19] splat cast wip

---
 clang/include/clang/AST/OperationKinds.def |  3 ++
 clang/include/clang/Sema/SemaHLSL.h        |  1 +
 clang/lib/CodeGen/CGExprAgg.cpp            | 42 ++++++++++++++++++++++
 clang/lib/CodeGen/CGExprScalar.cpp         | 16 +++++++++
 clang/lib/Sema/Sema.cpp                    |  1 +
 clang/lib/Sema/SemaCast.cpp                |  9 ++++-
 clang/lib/Sema/SemaHLSL.cpp                | 26 ++++++++++++++
 7 files changed, 97 insertions(+), 1 deletion(-)

diff --git a/clang/include/clang/AST/OperationKinds.def b/clang/include/clang/AST/OperationKinds.def
index b3dc7c3d8dc77..333fc7e1b1882 100644
--- a/clang/include/clang/AST/OperationKinds.def
+++ b/clang/include/clang/AST/OperationKinds.def
@@ -370,6 +370,9 @@ CAST_OPERATION(HLSLArrayRValue)
 // Aggregate by Value cast (HLSL only).
 CAST_OPERATION(HLSLElementwiseCast)
 
+// Splat cast for Aggregates (HLSL only).
+CAST_OPERATION(HLSLSplatCast)
+
 //===- Binary Operations  -------------------------------------------------===//
 // Operators listed in order of precedence.
 // Note that additions to this should also update the StmtVisitor class,
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 6e8ca2e4710de..7508b149b0d81 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -144,6 +144,7 @@ class SemaHLSL : public SemaBase {
   bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
   bool ContainsBitField(QualType BaseTy);
   bool CanPerformElementwiseCast(Expr *Src, QualType DestType);
+  bool CanPerformSplat(Expr *Src, QualType DestType);
   ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
 
   QualType getInoutParameterType(QualType Ty);
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index c3f1cbed6b39f..f26189bc4907c 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -491,6 +491,33 @@ static bool isTrivialFiller(Expr *E) {
   return false;
 }
 
+static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
+			      QualType DestTy, llvm::Value *SrcVal,
+			      QualType SrcTy, SourceLocation Loc) {
+  // Flatten our destination
+  SmallVector<QualType> DestTypes; // Flattened type
+  SmallVector<llvm::Value *, 4> IdxList;
+  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
+  // ^^ Flattened accesses to DestVal we want to store into
+  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList,
+		       DestTypes);
+
+  if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
+    assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast.");
+
+    SrcTy = VT->getElementType();
+    SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0,
+					      "vec.load");
+  }
+  assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
+  for(unsigned i = 0; i < StoreGEPList.size(); i ++) {
+    llvm::Value *Cast = CGF.EmitScalarConversion(SrcVal, SrcTy,
+						 DestTypes[i],
+						 Loc);
+    CGF.PerformStore(StoreGEPList[i], Cast);
+  }
+}
+
 // emit a flat cast where the RHS is a scalar, including vector
 static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
                                    QualType DestTy, llvm::Value *SrcVal,
@@ -963,6 +990,21 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
   case CK_HLSLArrayRValue:
     Visit(E->getSubExpr());
     break;
+  case CK_HLSLSplatCast: {
+    Expr *Src = E->getSubExpr();
+    QualType SrcTy = Src->getType();
+    RValue RV = CGF.EmitAnyExpr(Src);
+    QualType DestTy = E->getType();
+    Address DestVal = Dest.getAddress();
+    SourceLocation Loc = E->getExprLoc();
+    
+    if (RV.isScalar()) {
+      llvm::Value *SrcVal = RV.getScalarVal();
+      EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+      break;
+    }
+    llvm_unreachable("RHS of HLSL splat cast must be a scalar or vector.");
+  }
   case CK_HLSLElementwiseCast: {
     Expr *Src = E->getSubExpr();
     QualType SrcTy = Src->getType();
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 80daed7e53951..7dc2682bae42f 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2795,6 +2795,22 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
     return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
   }
+  case CK_HLSLSplatCast: {
+    assert(DestTy->isVectorType() && "Destination type must be a vector.");
+    auto *DestVecTy = DestTy->getAs<VectorType>();
+    QualType SrcTy = E->getType();
+    SourceLocation Loc = CE->getExprLoc();
+    Value *V = Visit(const_cast<Expr *>(E));
+    if (auto *VecTy = SrcTy->getAs<VectorType>()) {
+      assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
+      V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
+      SrcTy = VecTy->getElementType();
+    }
+    assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
+    Value *Cast = EmitScalarConversion(V, SrcTy,
+				       DestVecTy->getElementType(), Loc);
+    return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, "splat");
+  }
   case CK_HLSLElementwiseCast: {
     RValue RV = CGF.EmitAnyExpr(E);
     SourceLocation Loc = CE->getExprLoc();
diff --git a/clang/lib/Sema/Sema.cpp b/clang/lib/Sema/Sema.cpp
index 15c18f9a4525b..9eeefbb3c0023 100644
--- a/clang/lib/Sema/Sema.cpp
+++ b/clang/lib/Sema/Sema.cpp
@@ -709,6 +709,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty,
     case CK_ToVoid:
     case CK_NonAtomicToAtomic:
     case CK_HLSLArrayRValue:
+    case CK_HLSLSplatCast:
       break;
     }
   }
diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index 23be71ad8e2ae..56d8396b1e9d4 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -2776,9 +2776,16 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
   CheckedConversionKind CCK = FunctionalStyle
                                   ? CheckedConversionKind::FunctionalCast
                                   : CheckedConversionKind::CStyleCast;
+
   // This case should not trigger on regular vector splat
-  // vector cast, vector truncation, or special hlsl splat cases
   QualType SrcTy = SrcExpr.get()->getType();
+  if (Self.getLangOpts().HLSL &&
+      Self.HLSL().CanPerformSplat(SrcExpr.get(), DestType)) {
+    Kind = CK_HLSLSplatCast;
+    return;
+  }
+
+  // This case should not trigger on regular vector cast, vector truncation
   if (Self.getLangOpts().HLSL &&
       Self.HLSL().CanPerformElementwiseCast(SrcExpr.get(), DestType)) {
     if (SrcTy->isConstantArrayType())
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index ec6b5b45de42b..7c9365787fd4f 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2804,6 +2804,32 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) {
   return false;
 }
 
+// HLSL splat cast: Dest must be an aggregate or vector (never a scalar) and
+// Src a scalar or vec1; scalar->vector is left to the regular vector splat.
+bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
+  QualType SrcTy = Src->getType();
+  if (DestTy->isScalarType()) // (int)1.5f, (int)float1: not splats
+    return false;
+  if (SrcTy->isScalarType() && DestTy->isVectorType())
+    return false;
+
+  const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
+  if (!(SrcTy->isScalarType() || (SrcVecTy && SrcVecTy->getNumElements() == 1)))
+    return false;
+
+  if (SrcVecTy)
+    SrcTy = SrcVecTy->getElementType();
+
+  llvm::SmallVector<QualType> DestTypes;
+  BuildFlattenedTypeList(DestTy, DestTypes);
+
+  for(unsigned i = 0; i < DestTypes.size(); i ++) {
+    if (!CanPerformScalarCast(SrcTy, DestTypes[i]))
+      return false;
+  }
+  return true;
+}
+
 // Can we perform an HLSL Elementwise cast?
 // TODO: update this code when matrices are added; see issue #88060
 bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {

>From 24bea86dd7a2c39ca9f21480990236dc44df8cf3 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 05:19:00 +0000
Subject: [PATCH 02/19] make clang format happy

---
 clang/lib/CodeGen/CGExprAgg.cpp    | 19 ++++++++-----------
 clang/lib/CodeGen/CGExprScalar.cpp |  7 ++++---
 clang/lib/Sema/SemaHLSL.cpp        |  2 +-
 3 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index f26189bc4907c..60beabf3a5fd0 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -492,28 +492,25 @@ static bool isTrivialFiller(Expr *E) {
 }
 
 static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
-			      QualType DestTy, llvm::Value *SrcVal,
-			      QualType SrcTy, SourceLocation Loc) {
+                              QualType DestTy, llvm::Value *SrcVal,
+                              QualType SrcTy, SourceLocation Loc) {
   // Flatten our destination
   SmallVector<QualType> DestTypes; // Flattened type
   SmallVector<llvm::Value *, 4> IdxList;
   SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
   // ^^ Flattened accesses to DestVal we want to store into
-  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList,
-		       DestTypes);
+  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes);
 
   if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
     assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast.");
 
     SrcTy = VT->getElementType();
-    SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0,
-					      "vec.load");
+    SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0, "vec.load");
   }
   assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
-  for(unsigned i = 0; i < StoreGEPList.size(); i ++) {
-    llvm::Value *Cast = CGF.EmitScalarConversion(SrcVal, SrcTy,
-						 DestTypes[i],
-						 Loc);
+  for (unsigned i = 0; i < StoreGEPList.size(); i++) {
+    llvm::Value *Cast =
+        CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[i], Loc);
     CGF.PerformStore(StoreGEPList[i], Cast);
   }
 }
@@ -997,7 +994,7 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
     QualType DestTy = E->getType();
     Address DestVal = Dest.getAddress();
     SourceLocation Loc = E->getExprLoc();
-    
+
     if (RV.isScalar()) {
       llvm::Value *SrcVal = RV.getScalarVal();
       EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 7dc2682bae42f..4a20b693b101f 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2807,9 +2807,10 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
       SrcTy = VecTy->getElementType();
     }
     assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
-    Value *Cast = EmitScalarConversion(V, SrcTy,
-				       DestVecTy->getElementType(), Loc);
-    return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, "splat");
+    Value *Cast =
+        EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc);
+    return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast,
+                                     "splat");
   }
   case CK_HLSLElementwiseCast: {
     RValue RV = CGF.EmitAnyExpr(E);
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 7c9365787fd4f..024f778f8ffef 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2823,7 +2823,7 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
   llvm::SmallVector<QualType> DestTypes;
   BuildFlattenedTypeList(DestTy, DestTypes);
 
-  for(unsigned i = 0; i < DestTypes.size(); i ++) {
+  for (unsigned i = 0; i < DestTypes.size(); i++) {
     if (!CanPerformScalarCast(SrcTy, DestTypes[i]))
       return false;
   }

>From 3575617d436f04eac4faadc17ead8bfe561e7e7c Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 05:59:12 +0000
Subject: [PATCH 03/19] codegen test

---
 .../CodeGenHLSL/BasicFeatures/SplatCast.hlsl  | 87 +++++++++++++++++++
 1 file changed, 87 insertions(+)
 create mode 100644 clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl

diff --git a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
new file mode 100644
index 0000000000000..05359c1bce0ba
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
@@ -0,0 +1,87 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
+
+// array splat
+// CHECK-LABEL: define void {{.*}}call4
+// CHECK: [[B:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: store i32 3, ptr [[G1]], align 4
+// CHECK-NEXT: store i32 3, ptr [[G2]], align 4
+export void call4() {
+  int B[2] = {1,2};
+  B = (int[2])3;
+}
+
+// splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call8
+// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK-NEXT: [[B:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4
+export void call8() {
+  int1 A = {1};
+  int B[2] = {1,2};
+  B = (int[2])A;
+}
+
+// vector splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call1
+// CHECK: [[B:%.*]] = alloca <1 x float>, align 4
+// CHECK-NEXT: [[A:%.*]] = alloca <4 x i32>, align 16
+// CHECK-NEXT: store <1 x float> splat (float 1.000000e+00), ptr [[B]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x float>, ptr [[B]], align 4
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i64 0
+// CHECK-NEXT: [[C:%.*]] = fptosi float [[VL]] to i32
+// CHECK-NEXT: [[SI:%.*]] = insertelement <4 x i32> poison, i32 [[C]], i64 0
+// CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT: store <4 x i32> [[S]], ptr [[A]], align 16
+export void call1() {
+  float1 B = {1.0};
+  int4 A = (int4)B;
+}
+
+struct S {
+  int X;
+  float Y;
+};
+
+// struct splat from vector of length 1 (currently duplicates call5)
+// CHECK-LABEL: define void {{.*}}call3
+// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
+export void call3() {
+  int1 A = {1};
+  S s = (S)A;
+}
+
+// struct splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call5
+// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
+export void call5() {
+  int1 A = {1};
+  S s = (S)A;
+}

>From 288b8dac1c6fa4429c92c566a69da593c2ebb97c Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 17:38:58 +0000
Subject: [PATCH 04/19] Try to handle Cast in all the places it needs to be
 handled

---
 clang/lib/AST/Expr.cpp                        | 1 +
 clang/lib/AST/ExprConstant.cpp                | 2 ++
 clang/lib/CodeGen/CGExprAgg.cpp               | 1 +
 clang/lib/CodeGen/CGExprComplex.cpp           | 1 +
 clang/lib/CodeGen/CGExprConstant.cpp          | 1 +
 clang/lib/Edit/RewriteObjCFoundationAPI.cpp   | 1 +
 clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp | 1 +
 7 files changed, 8 insertions(+)

diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index c22aa66ba2cfb..bbb475fbb30f2 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -1957,6 +1957,7 @@ bool CastExpr::CastConsistency() const {
   case CK_HLSLArrayRValue:
   case CK_HLSLVectorTruncation:
   case CK_HLSLElementwiseCast:
+  case CK_HLSLSplatCast:
   CheckNoBasePath:
     assert(path_empty() && "Cast kind should not have a base path!");
     break;
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 192b679b4c995..ddc2d00883900 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -15029,6 +15029,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_FixedPointCast:
   case CK_IntegralToFixedPoint:
   case CK_MatrixCast:
+  case CK_HLSLSplatCast:
     llvm_unreachable("invalid cast kind for integral value");
 
   case CK_BitCast:
@@ -15907,6 +15908,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_MatrixCast:
   case CK_HLSLVectorTruncation:
   case CK_HLSLElementwiseCast:
+  case CK_HLSLSplatCast:
     llvm_unreachable("invalid cast kind for complex value");
 
   case CK_LValueToRValue:
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 60beabf3a5fd0..3584280e2fb9e 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -1592,6 +1592,7 @@ static bool castPreservesZero(const CastExpr *CE) {
   case CK_AtomicToNonAtomic:
   case CK_HLSLVectorTruncation:
   case CK_HLSLElementwiseCast:
+    // TODO is this true for CK_HLSLSplatCast
     return true;
 
   case CK_BaseToDerivedMemberPointer:
diff --git a/clang/lib/CodeGen/CGExprComplex.cpp b/clang/lib/CodeGen/CGExprComplex.cpp
index c2679ea92dc97..3832b9b598b24 100644
--- a/clang/lib/CodeGen/CGExprComplex.cpp
+++ b/clang/lib/CodeGen/CGExprComplex.cpp
@@ -611,6 +611,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
   case CK_HLSLVectorTruncation:
   case CK_HLSLArrayRValue:
   case CK_HLSLElementwiseCast:
+  case CK_HLSLSplatCast:
     llvm_unreachable("invalid cast kind for complex value");
 
   case CK_FloatingRealToComplex:
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index ef11798869d3b..b8ce83803b65f 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -1336,6 +1336,7 @@ class ConstExprEmitter
     case CK_HLSLVectorTruncation:
     case CK_HLSLArrayRValue:
     case CK_HLSLElementwiseCast:
+    case CK_HLSLSplatCast:
       return nullptr;
     }
     llvm_unreachable("Invalid CastKind");
diff --git a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
index 32f5ebb55155e..10d3f62fcd0a4 100644
--- a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
+++ b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
@@ -1086,6 +1086,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
 
     case CK_HLSLVectorTruncation:
     case CK_HLSLElementwiseCast:
+    case CK_HLSLSplatCast:
       llvm_unreachable("HLSL-specific cast in Objective-C?");
       break;
 
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
index 3a983421358c7..d75583f68eb6b 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
@@ -523,6 +523,7 @@ void ExprEngine::VisitCast(const CastExpr *CastE, const Expr *Ex,
       case CK_MatrixCast:
       case CK_VectorSplat:
       case CK_HLSLElementwiseCast:
+      case CK_HLSLSplatCast:
       case CK_HLSLVectorTruncation: {
         QualType resultType = CastE->getType();
         if (CastE->isGLValue())

>From 0650840642960d950d64e234e9641e34096a6c55 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Wed, 11 Dec 2024 20:54:39 +0000
Subject: [PATCH 05/19] get code compiling after rebase

---
 clang/lib/CodeGen/CGExprAgg.cpp | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 3584280e2fb9e..3330cd03628f7 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -496,10 +496,9 @@ static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
                               QualType SrcTy, SourceLocation Loc) {
   // Flatten our destination
   SmallVector<QualType> DestTypes; // Flattened type
-  SmallVector<llvm::Value *, 4> IdxList;
   SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
   // ^^ Flattened accesses to DestVal we want to store into
-  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes);
+  CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
 
   if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
     assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast.");
@@ -511,7 +510,15 @@ static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
   for (unsigned i = 0; i < StoreGEPList.size(); i++) {
     llvm::Value *Cast =
         CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[i], Loc);
-    CGF.PerformStore(StoreGEPList[i], Cast);
+
+    // store back
+    llvm::Value *Idx = StoreGEPList[i].second;
+    if (Idx) {
+      llvm::Value *V =
+          CGF.Builder.CreateLoad(StoreGEPList[i].first, "load.for.insert");
+      Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
+    }
+    CGF.Builder.CreateStore(Cast, StoreGEPList[i].first);
   }
 }
 

>From f924b13ada0c3344f3cc4f87a859f0ecd16705cb Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 12 Dec 2024 00:04:29 +0000
Subject: [PATCH 06/19] Self review

---
 clang/lib/CodeGen/CGExprScalar.cpp           | 15 +++++++-----
 clang/lib/Sema/SemaHLSL.cpp                  |  7 +++---
 clang/test/SemaHLSL/Language/SplatCasts.hlsl | 25 ++++++++++++++++++++
 3 files changed, 38 insertions(+), 9 deletions(-)
 create mode 100644 clang/test/SemaHLSL/Language/SplatCasts.hlsl

diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 4a20b693b101f..85c0265ea14b6 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2796,17 +2796,20 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
   }
   case CK_HLSLSplatCast: {
+    // This code should only handle splatting from vectors of length 1.
     assert(DestTy->isVectorType() && "Destination type must be a vector.");
     auto *DestVecTy = DestTy->getAs<VectorType>();
     QualType SrcTy = E->getType();
     SourceLocation Loc = CE->getExprLoc();
     Value *V = Visit(const_cast<Expr *>(E));
-    if (auto *VecTy = SrcTy->getAs<VectorType>()) {
-      assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
-      V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
-      SrcTy = VecTy->getElementType();
-    }
-    assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
+    assert(SrcTy->isVectorType() && "Invalid HLSL splat cast.");
+
+    auto *VecTy = SrcTy->getAs<VectorType>();
+    assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
+
+    V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
+    SrcTy = VecTy->getElementType();
+
     Value *Cast =
         EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc);
     return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast,
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 024f778f8ffef..432a42016789e 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2814,12 +2814,13 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
     return false;
 
   const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
-  if (!(SrcTy->isScalarType() || (SrcVecTy && SrcVecTy->getNumElements() == 1)))
-    return false;
-
   if (SrcVecTy)
     SrcTy = SrcVecTy->getElementType();
 
+  // Src isn't a scalar or a vector of length 1
+  if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
+    return false;
+
   llvm::SmallVector<QualType> DestTypes;
   BuildFlattenedTypeList(DestTy, DestTypes);
 
diff --git a/clang/test/SemaHLSL/Language/SplatCasts.hlsl b/clang/test/SemaHLSL/Language/SplatCasts.hlsl
new file mode 100644
index 0000000000000..593a8e67fd4a3
--- /dev/null
+++ b/clang/test/SemaHLSL/Language/SplatCasts.hlsl
@@ -0,0 +1,25 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -finclude-default-header -fnative-half-type %s -ast-dump | FileCheck %s
+
+// splat from vec1 to vec
+// CHECK-LABEL: call1
+// CHECK: CStyleCastExpr {{.*}} 'int3':'vector<int, 3>' <HLSLSplatCast>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'float1':'vector<float, 1>' lvalue Var {{.*}} 'A' 'float1':'vector<float, 1>'
+export void call1() {
+  float1 A = {1.0};
+  int3 B = (int3)A;
+}
+
+struct S {
+  int A;
+  float B;
+  int C;
+  float D;
+};
+
+// splat from scalar to aggregate
+// CHECK-LABEL: call2
+// CHECK: CStyleCastExpr {{.*}} 'S' <HLSLSplatCast>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 5
+export void call2() {
+  S s = (S)5; 
+}
\ No newline at end of file

>From 89ceeb7d6b445f10fa6b7deb8c10267cd292da7b Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 12 Dec 2024 05:59:55 +0000
Subject: [PATCH 07/19] move code back that broke tests

---
 clang/lib/Sema/SemaHLSL.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 432a42016789e..76ca24b10c60a 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2814,13 +2814,14 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
     return false;
 
   const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
-  if (SrcVecTy)
-    SrcTy = SrcVecTy->getElementType();
 
   // Src isn't a scalar or a vector of length 1
   if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
     return false;
 
+  if (SrcVecTy)
+    SrcTy = SrcVecTy->getElementType();
+
   llvm::SmallVector<QualType> DestTypes;
   BuildFlattenedTypeList(DestTy, DestTypes);
 

>From 7f5b3e4f39f2a4cf2d42e5281e70d900878c1a3b Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 12 Dec 2024 06:08:46 +0000
Subject: [PATCH 08/19] fix tests

---
 .../CodeGenHLSL/BasicFeatures/SplatCast.hlsl     | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
index 05359c1bce0ba..2de68479179dd 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
@@ -4,8 +4,8 @@
 // CHECK-LABEL: define void {{.*}}call4
 // CHECK: [[B:%.*]] = alloca [2 x i32], align 4
 // CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
-// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
-// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
 // CHECK-NEXT: store i32 3, ptr [[G1]], align 4
 // CHECK-NEXT: store i32 3, ptr [[G2]], align 4
 export void call4() {
@@ -20,8 +20,8 @@ export void call4() {
 // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
 // CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
 // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
-// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
-// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
 // CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
 // CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4
@@ -58,8 +58,8 @@ struct S {
 // CHECK: [[s:%.*]] = alloca %struct.S, align 4
 // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
 // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
-// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
-// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
 // CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
 // CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
@@ -75,8 +75,8 @@ export void call3() {
 // CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
 // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
 // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
-// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
-// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
 // CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
 // CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float

>From 844ba82eb5dcfbd0105db2d4943266fa8d009c17 Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Sat, 8 Feb 2025 09:07:05 -0800
Subject: [PATCH 09/19] add cast to cases

---
 clang/lib/CodeGen/CGExpr.cpp    | 1 +
 clang/lib/CodeGen/CGExprAgg.cpp | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 2bbc0791c6587..545d8b11a6a47 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -5339,6 +5339,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
   case CK_HLSLVectorTruncation:
   case CK_HLSLArrayRValue:
   case CK_HLSLElementwiseCast:
+  case CK_HLSLSplatCast:
     return EmitUnsupportedLValue(E, "unexpected cast lvalue");
 
   case CK_Dependent:
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 3330cd03628f7..b7fe62687b074 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -1599,7 +1599,7 @@ static bool castPreservesZero(const CastExpr *CE) {
   case CK_AtomicToNonAtomic:
   case CK_HLSLVectorTruncation:
   case CK_HLSLElementwiseCast:
-    // TODO is this true for CK_HLSLSplatCast
+  case CK_HLSLSplatCast:
     return true;
 
   case CK_BaseToDerivedMemberPointer:

>From 848315d47512a65ac98ecbb2c2102c6c4eef75f8 Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Mon, 10 Feb 2025 12:21:59 -0800
Subject: [PATCH 10/19] self review

---
 clang/include/clang/Sema/SemaHLSL.h          |  2 +-
 clang/lib/CodeGen/CGExprAgg.cpp              | 26 +++++++-------------
 clang/lib/CodeGen/CGExprScalar.cpp           | 11 +++------
 clang/lib/Sema/SemaCast.cpp                  |  7 +++++-
 clang/lib/Sema/SemaHLSL.cpp                  |  2 +-
 clang/test/SemaHLSL/Language/SplatCasts.hlsl |  2 +-
 6 files changed, 21 insertions(+), 29 deletions(-)

diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 7508b149b0d81..3772301afdd4f 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -144,7 +144,7 @@ class SemaHLSL : public SemaBase {
   bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
   bool ContainsBitField(QualType BaseTy);
   bool CanPerformElementwiseCast(Expr *Src, QualType DestType);
-  bool CanPerformSplat(Expr *Src, QualType DestType);
+  bool CanPerformSplatCast(Expr *Src, QualType DestType);
   ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
 
   QualType getInoutParameterType(QualType Ty);
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index b7fe62687b074..36557ccb15f1a 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -500,25 +500,19 @@ static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
   // ^^ Flattened accesses to DestVal we want to store into
   CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
 
-  if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
-    assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast.");
-
-    SrcTy = VT->getElementType();
-    SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0, "vec.load");
-  }
   assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
-  for (unsigned i = 0; i < StoreGEPList.size(); i++) {
+  for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; I++) {
     llvm::Value *Cast =
-        CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[i], Loc);
+        CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[I], Loc);
 
     // store back
-    llvm::Value *Idx = StoreGEPList[i].second;
+    llvm::Value *Idx = StoreGEPList[I].second;
     if (Idx) {
       llvm::Value *V =
-          CGF.Builder.CreateLoad(StoreGEPList[i].first, "load.for.insert");
+          CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
       Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
     }
-    CGF.Builder.CreateStore(Cast, StoreGEPList[i].first);
+    CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
   }
 }
 
@@ -1002,12 +996,10 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
     Address DestVal = Dest.getAddress();
     SourceLocation Loc = E->getExprLoc();
 
-    if (RV.isScalar()) {
-      llvm::Value *SrcVal = RV.getScalarVal();
-      EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
-      break;
-    }
-    llvm_unreachable("RHS of HLSL splat cast must be a scalar or vector.");
+    assert (RV.isScalar() && "RHS of HLSL splat cast must be a scalar.");
+    llvm::Value *SrcVal = RV.getScalarVal();
+    EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+    break;
   }
   case CK_HLSLElementwiseCast: {
     Expr *Src = E->getSubExpr();
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 85c0265ea14b6..b09c18f4a1229 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2796,19 +2796,14 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
   }
   case CK_HLSLSplatCast: {
-    // This code should only handle splatting from vectors of length 1.
+    // This cast should only handle splatting from vectors of length 1.
+    // But in Sema a cast should have been inserted to convert the vec1 to a scalar
     assert(DestTy->isVectorType() && "Destination type must be a vector.");
     auto *DestVecTy = DestTy->getAs<VectorType>();
     QualType SrcTy = E->getType();
     SourceLocation Loc = CE->getExprLoc();
     Value *V = Visit(const_cast<Expr *>(E));
-    assert(SrcTy->isVectorType() && "Invalid HLSL splat cast.");
-
-    auto *VecTy = SrcTy->getAs<VectorType>();
-    assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
-
-    V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
-    SrcTy = VecTy->getElementType();
+    assert(SrcTy->isBuiltinType() && "Invalid HLSL splat cast.");
 
     Value *Cast =
         EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc);
diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index 56d8396b1e9d4..a60bc0687461f 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -2780,7 +2780,12 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
   // This case should not trigger on regular vector splat
   QualType SrcTy = SrcExpr.get()->getType();
   if (Self.getLangOpts().HLSL &&
-      Self.HLSL().CanPerformSplat(SrcExpr.get(), DestType)) {
+      Self.HLSL().CanPerformSplatCast(SrcExpr.get(), DestType)) {
+    const VectorType *VT = SrcTy->getAs<VectorType>();
+    // change splat from vec1 case to splat from scalar
+    if (VT && VT->getNumElements() == 1)
+      SrcExpr = Self.ImpCastExprToType(SrcExpr.get(), VT->getElementType(),
+				       CK_HLSLVectorTruncation, VK_PRValue, nullptr, CCK);
     Kind = CK_HLSLSplatCast;
     return;
   }
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 76ca24b10c60a..d20bd281b7dc9 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2807,7 +2807,7 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) {
 // Can perform an HLSL splat cast if the Dest is an aggregate and the
 // Src is a scalar or a vector of length 1
 // Or if Dest is a vector and Src is a vector of length 1
-bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
+bool SemaHLSL::CanPerformSplatCast(Expr *Src, QualType DestTy) {
 
   QualType SrcTy = Src->getType();
   if (SrcTy->isScalarType() && DestTy->isVectorType())
diff --git a/clang/test/SemaHLSL/Language/SplatCasts.hlsl b/clang/test/SemaHLSL/Language/SplatCasts.hlsl
index 593a8e67fd4a3..cfe3b981dc92c 100644
--- a/clang/test/SemaHLSL/Language/SplatCasts.hlsl
+++ b/clang/test/SemaHLSL/Language/SplatCasts.hlsl
@@ -22,4 +22,4 @@ struct S {
 // CHECK-NEXt: IntegerLiteral {{.*}} 'int' 5
 export void call2() {
   S s = (S)5; 
-}
\ No newline at end of file
+}

>From cadd309a7d7c61a592d11cb306d44139df8d15ee Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Mon, 10 Feb 2025 15:26:45 -0800
Subject: [PATCH 11/19] disallow splatting things with bitfields, add tests to
 show casting bitfields not allowed. Add test showing splatting union is not
 allowed. Add test showing splatting union in elementwise cast is not allowed.

---
 clang/lib/Sema/SemaHLSL.cpp                   | 11 +++++--
 .../Language/ElementwiseCast-errors.hlsl      | 20 ++++++++++++
 .../SemaHLSL/Language/SplatCasts-errors.hlsl  | 32 +++++++++++++++++++
 3 files changed, 60 insertions(+), 3 deletions(-)
 create mode 100644 clang/test/SemaHLSL/Language/SplatCasts-errors.hlsl

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index d20bd281b7dc9..252121f88af05 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2771,7 +2771,7 @@ bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
 }
 
 // Detect if a type contains a bitfield. Will be removed when
-// bitfield support is added to HLSLElementwiseCast
+// bitfield support is added to HLSLElementwiseCast and HLSLSplatCast
 bool SemaHLSL::ContainsBitField(QualType BaseTy) {
   llvm::SmallVector<QualType, 16> WorkList;
   WorkList.push_back(BaseTy);
@@ -2822,11 +2822,16 @@ bool SemaHLSL::CanPerformSplatCast(Expr *Src, QualType DestTy) {
   if (SrcVecTy)
     SrcTy = SrcVecTy->getElementType();
 
+  if (ContainsBitField(DestTy))
+    return false;
+
   llvm::SmallVector<QualType> DestTypes;
   BuildFlattenedTypeList(DestTy, DestTypes);
 
-  for (unsigned i = 0; i < DestTypes.size(); i++) {
-    if (!CanPerformScalarCast(SrcTy, DestTypes[i]))
+  for (unsigned I = 0, Size = DestTypes.size(); I < Size; I++) {
+    if (DestTypes[I]->isUnionType())
+      return false;
+    if (!CanPerformScalarCast(SrcTy, DestTypes[I]))
       return false;
   }
   return true;
diff --git a/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl b/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl
index c900c83a063a0..b7085bc69547b 100644
--- a/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl
+++ b/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl
@@ -27,3 +27,23 @@ export void cantCast3() {
   S s = (S)C;
   // expected-error at -1 {{no matching conversion for C-style cast from 'int2' (aka 'vector<int, 2>') to 'S'}}
 }
+
+struct R {
+// expected-note at -1 {{candidate constructor (the implicit copy constructor) not viable: no known conversion from 'int2' (aka 'vector<int, 2>') to 'const R' for 1st argument}}
+// expected-note at -2 {{candidate constructor (the implicit move constructor) not viable: no known conversion from 'int2' (aka 'vector<int, 2>') to 'R' for 1st argument}}
+// expected-note at -3 {{candidate constructor (the implicit default constructor) not viable: requires 0 arguments, but 1 was provided}}
+  int A;
+  union {
+    float F;
+    int4 G;
+  };
+};
+
+export void cantCast4() {
+  int2 A = {1,2};
+  R r = R(A);
+  // expected-error at -1 {{no matching conversion for functional-style cast from 'int2' (aka 'vector<int, 2>') to 'R'}}
+  R r2 = {1, 2};
+  int2 B = (int2)r2;
+  // expected-error at -1 {{cannot convert 'R' to 'int2' (aka 'vector<int, 2>') without a conversion operator}}
+}
diff --git a/clang/test/SemaHLSL/Language/SplatCasts-errors.hlsl b/clang/test/SemaHLSL/Language/SplatCasts-errors.hlsl
new file mode 100644
index 0000000000000..b0234c597eadf
--- /dev/null
+++ b/clang/test/SemaHLSL/Language/SplatCasts-errors.hlsl
@@ -0,0 +1,32 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -verify
+
+struct S {
+// expected-note at -1 {{candidate constructor (the implicit copy constructor) not viable: no known conversion from 'int' to 'const S' for 1st argument}}
+// expected-note at -2 {{candidate constructor (the implicit move constructor) not viable: no known conversion from 'int' to 'S' for 1st argument}}
+// expected-note at -3 {{candidate constructor (the implicit default constructor) not viable: requires 0 arguments, but 1 was provided}}
+  int A : 8;
+  int B;
+};
+
+struct R {
+// expected-note at -1 {{candidate constructor (the implicit copy constructor) not viable: no known conversion from 'int' to 'const R' for 1st argument}}
+// expected-note at -2 {{candidate constructor (the implicit move constructor) not viable: no known conversion from 'int' to 'R' for 1st argument}}
+// expected-note at -3 {{candidate constructor (the implicit default constructor) not viable: requires 0 arguments, but 1 was provided}}
+  int A;
+  union {
+    float F;
+    int4 G;
+  };
+};
+
+// casting types which contain bitfields is not yet supported.
+export void cantCast() {
+  S s = (S)1;
+  // expected-error at -1 {{no matching conversion for C-style cast from 'int' to 'S'}}
+}
+
+// Can't cast a union
+export void cantCast2() {
+  R r = (R)1;
+  // expected-error at -1 {{no matching conversion for C-style cast from 'int' to 'R'}}
+}

>From 93c0450a22ad82a939f322e40d4c4e5a6b9d56dd Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Mon, 10 Feb 2025 19:12:31 -0800
Subject: [PATCH 12/19] fix tests + associated issues in code

---
 clang/lib/Sema/SemaCast.cpp                   | 27 ++++++++++---------
 clang/lib/Sema/SemaHLSL.cpp                   |  4 ++-
 .../CodeGenHLSL/BasicFeatures/SplatCast.hlsl  |  8 +++---
 clang/test/SemaHLSL/Language/SplatCasts.hlsl  |  1 +
 4 files changed, 23 insertions(+), 17 deletions(-)

diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index a60bc0687461f..2f5deba7e1258 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -2777,19 +2777,7 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
                                   ? CheckedConversionKind::FunctionalCast
                                   : CheckedConversionKind::CStyleCast;
 
-  // This case should not trigger on regular vector splat
   QualType SrcTy = SrcExpr.get()->getType();
-  if (Self.getLangOpts().HLSL &&
-      Self.HLSL().CanPerformSplatCast(SrcExpr.get(), DestType)) {
-    const VectorType *VT = SrcTy->getAs<VectorType>();
-    // change splat from vec1 case to splat from scalar
-    if (VT && VT->getNumElements() == 1)
-      SrcExpr = Self.ImpCastExprToType(SrcExpr.get(), VT->getElementType(),
-				       CK_HLSLVectorTruncation, VK_PRValue, nullptr, CCK);
-    Kind = CK_HLSLSplatCast;
-    return;
-  }
-
   // This case should not trigger on regular vector cast, vector truncation
   if (Self.getLangOpts().HLSL &&
       Self.HLSL().CanPerformElementwiseCast(SrcExpr.get(), DestType)) {
@@ -2801,6 +2789,21 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
     return;
   }
 
+  // This case should not trigger on regular vector splat
+  // If the relative order of this and the HLSLElementWise cast checks
+  // are changed, it might change which cast handles what in a few cases
+  if (Self.getLangOpts().HLSL &&
+      Self.HLSL().CanPerformSplatCast(SrcExpr.get(), DestType)) {
+    const VectorType *VT = SrcTy->getAs<VectorType>();
+    // change splat from vec1 case to splat from scalar
+    if (VT && VT->getNumElements() == 1)
+      SrcExpr = Self.ImpCastExprToType(SrcExpr.get(), VT->getElementType(),
+				       CK_HLSLVectorTruncation,
+				       SrcExpr.get()->getValueKind(), nullptr, CCK);
+    Kind = CK_HLSLSplatCast;
+    return;
+  }
+
   if (ValueKind == VK_PRValue && !DestType->isRecordType() &&
       !isPlaceholder(BuiltinType::Overload)) {
     SrcExpr = Self.DefaultFunctionArrayLvalueConversion(SrcExpr.get());
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 252121f88af05..9898732f30b27 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2810,7 +2810,9 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) {
 bool SemaHLSL::CanPerformSplatCast(Expr *Src, QualType DestTy) {
 
   QualType SrcTy = Src->getType();
-  if (SrcTy->isScalarType() && DestTy->isVectorType())
+  // Not a valid HLSL Splat cast if Dest is a scalar or if this is going to
+  // be a vector splat from a scalar.
+  if ((SrcTy->isScalarType() && DestTy->isVectorType()) || DestTy->isScalarType())
     return false;
 
   const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
index 2de68479179dd..0bc3e3fbd86cc 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
@@ -20,9 +20,9 @@ export void call4() {
 // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
 // CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
 // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
 // CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
 // CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
-// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
 // CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4
 export void call8() {
@@ -37,7 +37,7 @@ export void call8() {
 // CHECK-NEXT: [[A:%.*]] = alloca <4 x i32>, align 16
 // CHECK-NEXT: store <1 x float> splat (float 1.000000e+00), ptr [[B]], align 4
 // CHECK-NEXT: [[L:%.*]] = load <1 x float>, ptr [[B]], align 4
-// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i64 0
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i32 0
 // CHECK-NEXT: [[C:%.*]] = fptosi float [[VL]] to i32
 // CHECK-NEXT: [[SI:%.*]] = insertelement <4 x i32> poison, i32 [[C]], i64 0
 // CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
@@ -58,9 +58,9 @@ struct S {
 // CHECK: [[s:%.*]] = alloca %struct.S, align 4
 // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
 // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
 // CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
 // CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
-// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
 // CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
 // CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
@@ -75,9 +75,9 @@ export void call3() {
 // CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
 // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
 // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
 // CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
 // CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
-// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
 // CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
 // CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
diff --git a/clang/test/SemaHLSL/Language/SplatCasts.hlsl b/clang/test/SemaHLSL/Language/SplatCasts.hlsl
index cfe3b981dc92c..c57a577e8929f 100644
--- a/clang/test/SemaHLSL/Language/SplatCasts.hlsl
+++ b/clang/test/SemaHLSL/Language/SplatCasts.hlsl
@@ -3,6 +3,7 @@
 // splat from vec1 to vec
 // CHECK-LABEL: call1
 // CHECK: CStyleCastExpr {{.*}} 'int3':'vector<int, 3>' <HLSLSplatCast>
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float' lvalue <HLSLVectorTruncation> part_of_explicit_cast
 // CHECK-NEXT: DeclRefExpr {{.*}} 'float1':'vector<float, 1>' lvalue Var {{.*}} 'A' 'float1':'vector<float, 1>'
 export void call1() {
   float1 A = {1.0};

>From 1d317f21a280aa89b39cfe578fee3f022044e432 Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Mon, 10 Feb 2025 19:16:00 -0800
Subject: [PATCH 13/19] clang format

---
 clang/lib/CodeGen/CGExprAgg.cpp    | 2 +-
 clang/lib/CodeGen/CGExprScalar.cpp | 3 ++-
 clang/lib/Sema/SemaCast.cpp        | 6 +++---
 clang/lib/Sema/SemaHLSL.cpp        | 3 ++-
 4 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 36557ccb15f1a..69e77667648d0 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -996,7 +996,7 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
     Address DestVal = Dest.getAddress();
     SourceLocation Loc = E->getExprLoc();
 
-    assert (RV.isScalar() && "RHS of HLSL splat cast must be a scalar.");
+    assert(RV.isScalar() && "RHS of HLSL splat cast must be a scalar.");
     llvm::Value *SrcVal = RV.getScalarVal();
     EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
     break;
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index b09c18f4a1229..cd7d9c243fcb2 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2797,7 +2797,8 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
   }
   case CK_HLSLSplatCast: {
     // This cast should only handle splatting from vectors of length 1.
-    // But in Sema a cast should have been inserted to convert the vec1 to a scalar
+    // But in Sema a cast should have been inserted to convert the vec1 to a
+    // scalar
     assert(DestTy->isVectorType() && "Destination type must be a vector.");
     auto *DestVecTy = DestTy->getAs<VectorType>();
     QualType SrcTy = E->getType();
diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index 2f5deba7e1258..e733dbc1c0d62 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -2797,9 +2797,9 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
     const VectorType *VT = SrcTy->getAs<VectorType>();
     // change splat from vec1 case to splat from scalar
     if (VT && VT->getNumElements() == 1)
-      SrcExpr = Self.ImpCastExprToType(SrcExpr.get(), VT->getElementType(),
-				       CK_HLSLVectorTruncation,
-				       SrcExpr.get()->getValueKind(), nullptr, CCK);
+      SrcExpr = Self.ImpCastExprToType(
+          SrcExpr.get(), VT->getElementType(), CK_HLSLVectorTruncation,
+          SrcExpr.get()->getValueKind(), nullptr, CCK);
     Kind = CK_HLSLSplatCast;
     return;
   }
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9898732f30b27..b68a589e6de81 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2812,7 +2812,8 @@ bool SemaHLSL::CanPerformSplatCast(Expr *Src, QualType DestTy) {
   QualType SrcTy = Src->getType();
   // Not a valid HLSL Splat cast if Dest is a scalar or if this is going to
   // be a vector splat from a scalar.
-  if ((SrcTy->isScalarType() && DestTy->isVectorType()) || DestTy->isScalarType())
+  if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
+      DestTy->isScalarType())
     return false;
 
   const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();

>From 41f54fe9c0dc2f5e8ea9121949bbf879eb48f400 Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Thu, 13 Feb 2025 09:42:21 -0800
Subject: [PATCH 14/19] rename cast from HLSLSplatCast to
 HLSLAggregateSplatCast

---
 clang/include/clang/AST/OperationKinds.def           |  2 +-
 clang/lib/AST/Expr.cpp                               |  2 +-
 clang/lib/AST/ExprConstant.cpp                       |  4 ++--
 clang/lib/CodeGen/CGExpr.cpp                         |  2 +-
 clang/lib/CodeGen/CGExprAgg.cpp                      | 12 ++++++------
 clang/lib/CodeGen/CGExprComplex.cpp                  |  2 +-
 clang/lib/CodeGen/CGExprConstant.cpp                 |  2 +-
 clang/lib/CodeGen/CGExprScalar.cpp                   |  6 +++---
 clang/lib/Edit/RewriteObjCFoundationAPI.cpp          |  2 +-
 clang/lib/Sema/Sema.cpp                              |  2 +-
 clang/lib/Sema/SemaCast.cpp                          |  2 +-
 clang/lib/Sema/SemaExpr.cpp                          |  3 ++-
 clang/lib/Sema/SemaHLSL.cpp                          |  2 +-
 clang/lib/Sema/SemaType.cpp                          |  6 ++++--
 clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp        |  2 +-
 .../{SplatCasts.hlsl => AggregateSplatCasts.hlsl}    |  4 ++--
 16 files changed, 29 insertions(+), 26 deletions(-)
 rename clang/test/SemaHLSL/Language/{SplatCasts.hlsl => AggregateSplatCasts.hlsl} (82%)

diff --git a/clang/include/clang/AST/OperationKinds.def b/clang/include/clang/AST/OperationKinds.def
index 333fc7e1b1882..790dd572a7c99 100644
--- a/clang/include/clang/AST/OperationKinds.def
+++ b/clang/include/clang/AST/OperationKinds.def
@@ -371,7 +371,7 @@ CAST_OPERATION(HLSLArrayRValue)
 CAST_OPERATION(HLSLElementwiseCast)
 
 // Splat cast for Aggregates (HLSL only).
-CAST_OPERATION(HLSLSplatCast)
+CAST_OPERATION(HLSLAggregateSplatCast)
 
 //===- Binary Operations  -------------------------------------------------===//
 // Operators listed in order of precedence.
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index bbb475fbb30f2..4cba1a573f35e 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -1957,7 +1957,7 @@ bool CastExpr::CastConsistency() const {
   case CK_HLSLArrayRValue:
   case CK_HLSLVectorTruncation:
   case CK_HLSLElementwiseCast:
-  case CK_HLSLSplatCast:
+  case CK_HLSLAggregateSplatCast:
   CheckNoBasePath:
     assert(path_empty() && "Cast kind should not have a base path!");
     break;
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index ddc2d00883900..c3a562141b5b8 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -15029,7 +15029,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_FixedPointCast:
   case CK_IntegralToFixedPoint:
   case CK_MatrixCast:
-  case CK_HLSLSplatCast:
+  case CK_HLSLAggregateSplatCast:
     llvm_unreachable("invalid cast kind for integral value");
 
   case CK_BitCast:
@@ -15908,7 +15908,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_MatrixCast:
   case CK_HLSLVectorTruncation:
   case CK_HLSLElementwiseCast:
-  case CK_HLSLSplatCast:
+  case CK_HLSLAggregateSplatCast:
     llvm_unreachable("invalid cast kind for complex value");
 
   case CK_LValueToRValue:
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 545d8b11a6a47..0b0ffd2db853f 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -5339,7 +5339,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
   case CK_HLSLVectorTruncation:
   case CK_HLSLArrayRValue:
   case CK_HLSLElementwiseCast:
-  case CK_HLSLSplatCast:
+  case CK_HLSLAggregateSplatCast:
     return EmitUnsupportedLValue(E, "unexpected cast lvalue");
 
   case CK_Dependent:
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 69e77667648d0..33182c030b6b9 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -491,9 +491,9 @@ static bool isTrivialFiller(Expr *E) {
   return false;
 }
 
-static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
-                              QualType DestTy, llvm::Value *SrcVal,
-                              QualType SrcTy, SourceLocation Loc) {
+static void EmitHLSLAggregateSplatCast(CodeGenFunction &CGF, Address DestVal,
+                                       QualType DestTy, llvm::Value *SrcVal,
+                                       QualType SrcTy, SourceLocation Loc) {
   // Flatten our destination
   SmallVector<QualType> DestTypes; // Flattened type
   SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
@@ -988,7 +988,7 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
   case CK_HLSLArrayRValue:
     Visit(E->getSubExpr());
     break;
-  case CK_HLSLSplatCast: {
+  case CK_HLSLAggregateSplatCast: {
     Expr *Src = E->getSubExpr();
     QualType SrcTy = Src->getType();
     RValue RV = CGF.EmitAnyExpr(Src);
@@ -998,7 +998,7 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
 
     assert(RV.isScalar() && "RHS of HLSL splat cast must be a scalar.");
     llvm::Value *SrcVal = RV.getScalarVal();
-    EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+    EmitHLSLAggregateSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
     break;
   }
   case CK_HLSLElementwiseCast: {
@@ -1591,7 +1591,7 @@ static bool castPreservesZero(const CastExpr *CE) {
   case CK_AtomicToNonAtomic:
   case CK_HLSLVectorTruncation:
   case CK_HLSLElementwiseCast:
-  case CK_HLSLSplatCast:
+  case CK_HLSLAggregateSplatCast:
     return true;
 
   case CK_BaseToDerivedMemberPointer:
diff --git a/clang/lib/CodeGen/CGExprComplex.cpp b/clang/lib/CodeGen/CGExprComplex.cpp
index 3832b9b598b24..ff7c55be246cc 100644
--- a/clang/lib/CodeGen/CGExprComplex.cpp
+++ b/clang/lib/CodeGen/CGExprComplex.cpp
@@ -611,7 +611,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
   case CK_HLSLVectorTruncation:
   case CK_HLSLArrayRValue:
   case CK_HLSLElementwiseCast:
-  case CK_HLSLSplatCast:
+  case CK_HLSLAggregateSplatCast:
     llvm_unreachable("invalid cast kind for complex value");
 
   case CK_FloatingRealToComplex:
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index b8ce83803b65f..ee5874b26f534 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -1336,7 +1336,7 @@ class ConstExprEmitter
     case CK_HLSLVectorTruncation:
     case CK_HLSLArrayRValue:
     case CK_HLSLElementwiseCast:
-    case CK_HLSLSplatCast:
+    case CK_HLSLAggregateSplatCast:
       return nullptr;
     }
     llvm_unreachable("Invalid CastKind");
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index cd7d9c243fcb2..912af7dd5b230 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2795,7 +2795,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
     return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
   }
-  case CK_HLSLSplatCast: {
+  case CK_HLSLAggregateSplatCast: {
     // This cast should only handle splatting from vectors of length 1.
     // But in Sema a cast should have been inserted to convert the vec1 to a
     // scalar
@@ -2804,7 +2804,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     QualType SrcTy = E->getType();
     SourceLocation Loc = CE->getExprLoc();
     Value *V = Visit(const_cast<Expr *>(E));
-    assert(SrcTy->isBuiltinType() && "Invalid HLSL splat cast.");
+    assert(SrcTy->isBuiltinType() && "Invalid HLSL Aggregate splat cast.");
 
     Value *Cast =
         EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc);
@@ -2816,7 +2816,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     SourceLocation Loc = CE->getExprLoc();
     QualType SrcTy = E->getType();
 
-    assert(RV.isAggregate() && "Not a valid HLSL Flat Cast.");
+    assert(RV.isAggregate() && "Not a valid HLSL Elementwise Cast.");
     // RHS is an aggregate
     Address SrcVal = RV.getAggregateAddress();
     return EmitHLSLElementwiseCast(CGF, SrcVal, SrcTy, DestTy, Loc);
diff --git a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
index 10d3f62fcd0a4..627a1d6fb3dd5 100644
--- a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
+++ b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
@@ -1086,7 +1086,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
 
     case CK_HLSLVectorTruncation:
     case CK_HLSLElementwiseCast:
-    case CK_HLSLSplatCast:
+    case CK_HLSLAggregateSplatCast:
       llvm_unreachable("HLSL-specific cast in Objective-C?");
       break;
 
diff --git a/clang/lib/Sema/Sema.cpp b/clang/lib/Sema/Sema.cpp
index 9eeefbb3c0023..afd1d7a4e36c1 100644
--- a/clang/lib/Sema/Sema.cpp
+++ b/clang/lib/Sema/Sema.cpp
@@ -709,7 +709,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty,
     case CK_ToVoid:
     case CK_NonAtomicToAtomic:
     case CK_HLSLArrayRValue:
-    case CK_HLSLSplatCast:
+    case CK_HLSLAggregateSplatCast:
       break;
     }
   }
diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index e733dbc1c0d62..4a88ce84a1a08 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -2800,7 +2800,7 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
       SrcExpr = Self.ImpCastExprToType(
           SrcExpr.get(), VT->getElementType(), CK_HLSLVectorTruncation,
           SrcExpr.get()->getValueKind(), nullptr, CCK);
-    Kind = CK_HLSLSplatCast;
+    Kind = CK_HLSLAggregateSplatCast;
     return;
   }
 
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 3cd4010740d19..ef06d52284ff8 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -20678,7 +20678,8 @@ ExprResult RebuildUnknownAnyExpr::VisitCallExpr(CallExpr *E) {
   const FunctionType *FnType = CalleeType->castAs<FunctionType>();
 
   // Verify that this is a legal result type of a function.
-  if (DestType->isArrayType() || DestType->isFunctionType()) {
+  if ((!S.getLangOpts().HLSL && DestType->isArrayType()) ||
+      DestType->isFunctionType()) {
     unsigned diagID = diag::err_func_returning_array_function;
     if (Kind == FK_BlockPointer)
       diagID = diag::err_block_returning_array_function;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index b68a589e6de81..4daf465482586 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2771,7 +2771,7 @@ bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
 }
 
 // Detect if a type contains a bitfield. Will be removed when
-// bitfield support is added to HLSLElementwiseCast and HLSLSplatCast
+// bitfield support is added to HLSLElementwiseCast and HLSLAggregateSplatCast
 bool SemaHLSL::ContainsBitField(QualType BaseTy) {
   llvm::SmallVector<QualType, 16> WorkList;
   WorkList.push_back(BaseTy);
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index fd1b73de912e5..df95677457c1f 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2527,7 +2527,7 @@ QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
 }
 
 bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) {
-  if (T->isArrayType() || T->isFunctionType()) {
+  if ((!getLangOpts().HLSL && T->isArrayType()) || T->isFunctionType()) {
     Diag(Loc, diag::err_func_returning_array_function)
       << T->isFunctionType() << T;
     return true;
@@ -4931,7 +4931,9 @@ static TypeSourceInfo *GetFullTypeForDeclarator(TypeProcessingState &state,
 
       // C99 6.7.5.3p1: The return type may not be a function or array type.
       // For conversion functions, we'll diagnose this particular error later.
-      if (!D.isInvalidType() && (T->isArrayType() || T->isFunctionType()) &&
+      if (!D.isInvalidType() &&
+          ((!S.getLangOpts().HLSL && T->isArrayType()) ||
+           T->isFunctionType()) &&
           (D.getName().getKind() !=
            UnqualifiedIdKind::IK_ConversionFunctionId)) {
         unsigned diagID = diag::err_func_returning_array_function;
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
index d75583f68eb6b..1061dafbb2473 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
@@ -523,7 +523,7 @@ void ExprEngine::VisitCast(const CastExpr *CastE, const Expr *Ex,
       case CK_MatrixCast:
       case CK_VectorSplat:
       case CK_HLSLElementwiseCast:
-      case CK_HLSLSplatCast:
+      case CK_HLSLAggregateSplatCast:
       case CK_HLSLVectorTruncation: {
         QualType resultType = CastE->getType();
         if (CastE->isGLValue())
diff --git a/clang/test/SemaHLSL/Language/SplatCasts.hlsl b/clang/test/SemaHLSL/Language/AggregateSplatCasts.hlsl
similarity index 82%
rename from clang/test/SemaHLSL/Language/SplatCasts.hlsl
rename to clang/test/SemaHLSL/Language/AggregateSplatCasts.hlsl
index c57a577e8929f..5a735de4ea346 100644
--- a/clang/test/SemaHLSL/Language/SplatCasts.hlsl
+++ b/clang/test/SemaHLSL/Language/AggregateSplatCasts.hlsl
@@ -2,7 +2,7 @@
 
 // splat from vec1 to vec
 // CHECK-LABEL: call1
-// CHECK: CStyleCastExpr {{.*}} 'int3':'vector<int, 3>' <HLSLSplatCast>
+// CHECK: CStyleCastExpr {{.*}} 'int3':'vector<int, 3>' <HLSLAggregateSplatCast>
 // CHECK-NEXT: ImplicitCastExpr {{.*}} 'float' lvalue <HLSLVectorTruncation> part_of_explicit_cast
 // CHECK-NEXT: DeclRefExpr {{.*}} 'float1':'vector<float, 1>' lvalue Var {{.*}} 'A' 'float1':'vector<float, 1>'
 export void call1() {
@@ -19,7 +19,7 @@ struct S {
 
 // splat from scalar to aggregate
 // CHECK-LABEL: call2
-// CHECK: CStyleCastExpr {{.*}} 'S' <HLSLSplatCast>
+// CHECK: CStyleCastExpr {{.*}} 'S' <HLSLAggregateSplatCast>
 // CHECK-NEXt: IntegerLiteral {{.*}} 'int' 5
 export void call2() {
   S s = (S)5; 

>From 7331af63c8516f54ac314835f1deb62c3b6c08c5 Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Thu, 13 Feb 2025 11:03:02 -0800
Subject: [PATCH 15/19] address pr comments

---
 clang/lib/CodeGen/CGExprAgg.cpp                        |  4 ++--
 clang/lib/Sema/SemaHLSL.cpp                            | 10 +++++-----
 .../{SplatCast.hlsl => AggregateSplatCast.hlsl}        |  2 +-
 3 files changed, 8 insertions(+), 8 deletions(-)
 rename clang/test/CodeGenHLSL/BasicFeatures/{SplatCast.hlsl => AggregateSplatCast.hlsl} (99%)

diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 33182c030b6b9..08a9c52a74e94 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -500,8 +500,8 @@ static void EmitHLSLAggregateSplatCast(CodeGenFunction &CGF, Address DestVal,
   // ^^ Flattened accesses to DestVal we want to store into
   CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
 
-  assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
-  for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; I++) {
+  assert(SrcTy->isScalarType() && "Invalid HLSL Aggregate splat cast.");
+  for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; ++I) {
     llvm::Value *Cast =
         CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[I], Loc);
 
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 4daf465482586..7ea141ed44eca 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2804,14 +2804,14 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) {
   return false;
 }
 
-// Can perform an HLSL splat cast if the Dest is an aggregate and the
+// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
 // Src is a scalar or a vector of length 1
 // Or if Dest is a vector and Src is a vector of length 1
-bool SemaHLSL::CanPerformSplatCast(Expr *Src, QualType DestTy) {
+bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
 
   QualType SrcTy = Src->getType();
-  // Not a valid HLSL Splat cast if Dest is a scalar or if this is going to
-  // be a vector splat from a scalar.
+  // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
+  // going to be a vector splat from a scalar.
   if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
       DestTy->isScalarType())
     return false;
@@ -2831,7 +2831,7 @@ bool SemaHLSL::CanPerformSplatCast(Expr *Src, QualType DestTy) {
   llvm::SmallVector<QualType> DestTypes;
   BuildFlattenedTypeList(DestTy, DestTypes);
 
-  for (unsigned I = 0, Size = DestTypes.size(); I < Size; I++) {
+  for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
     if (DestTypes[I]->isUnionType())
       return false;
     if (!CanPerformScalarCast(SrcTy, DestTypes[I]))
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
similarity index 99%
rename from clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
rename to clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
index 0bc3e3fbd86cc..42b6abec1b3d8 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
@@ -52,7 +52,7 @@ struct S {
   float Y;
 };
 
-// struct splats?
+// struct splats
 // CHECK-LABEL: define void {{.*}}call3
 // CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
 // CHECK: [[s:%.*]] = alloca %struct.S, align 4

>From 5fb0b9ffac2a11a863bb53842df627f3764c8106 Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Thu, 13 Feb 2025 13:38:23 -0800
Subject: [PATCH 16/19] In the vector case, do more work in Sema so CodeGen
 can reuse the VectorSplat codegen

---
 clang/include/clang/Sema/SemaHLSL.h           |  2 +-
 clang/lib/CodeGen/CGExprScalar.cpp            | 21 +++++--------------
 clang/lib/Sema/SemaCast.cpp                   |  9 +++++++-
 .../Language/AggregateSplatCasts.hlsl         |  1 +
 4 files changed, 15 insertions(+), 18 deletions(-)

diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 3772301afdd4f..c9266ea50e4bf 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -144,7 +144,7 @@ class SemaHLSL : public SemaBase {
   bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
   bool ContainsBitField(QualType BaseTy);
   bool CanPerformElementwiseCast(Expr *Src, QualType DestType);
-  bool CanPerformSplatCast(Expr *Src, QualType DestType);
+  bool CanPerformAggregateSplatCast(Expr *Src, QualType DestType);
   ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
 
   QualType getInoutParameterType(QualType Ty);
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 912af7dd5b230..30f01496ba221 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2643,6 +2643,11 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     return EmitScalarConversion(Visit(E), E->getType(), DestTy,
                                 CE->getExprLoc());
   }
+  // CK_HLSLAggregateSplatCast only handles splatting to vectors from a vec1.
+  // Sema has already inserted casts that convert the Src Expr to a scalar and
+  // perform any necessary scalar conversion, so this cast can be handled by
+  // the regular vector splat cast code.
+  case CK_HLSLAggregateSplatCast:
   case CK_VectorSplat: {
     llvm::Type *DstTy = ConvertType(DestTy);
     Value *Elt = Visit(const_cast<Expr *>(E));
@@ -2795,22 +2800,6 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
     return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
   }
-  case CK_HLSLAggregateSplatCast: {
-    // This cast should only handle splatting from vectors of length 1.
-    // But in Sema a cast should have been inserted to convert the vec1 to a
-    // scalar
-    assert(DestTy->isVectorType() && "Destination type must be a vector.");
-    auto *DestVecTy = DestTy->getAs<VectorType>();
-    QualType SrcTy = E->getType();
-    SourceLocation Loc = CE->getExprLoc();
-    Value *V = Visit(const_cast<Expr *>(E));
-    assert(SrcTy->isBuiltinType() && "Invalid HLSL Aggregate splat cast.");
-
-    Value *Cast =
-        EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc);
-    return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast,
-                                     "splat");
-  }
   case CK_HLSLElementwiseCast: {
     RValue RV = CGF.EmitAnyExpr(E);
     SourceLocation Loc = CE->getExprLoc();
diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index 4a88ce84a1a08..8972957ded9f5 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -2793,13 +2793,20 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
   // If the relative order of this and the HLSLElementWise cast checks
   // are changed, it might change which cast handles what in a few cases
   if (Self.getLangOpts().HLSL &&
-      Self.HLSL().CanPerformSplatCast(SrcExpr.get(), DestType)) {
+      Self.HLSL().CanPerformAggregateSplatCast(SrcExpr.get(), DestType)) {
     const VectorType *VT = SrcTy->getAs<VectorType>();
     // change splat from vec1 case to splat from scalar
     if (VT && VT->getNumElements() == 1)
       SrcExpr = Self.ImpCastExprToType(
           SrcExpr.get(), VT->getElementType(), CK_HLSLVectorTruncation,
           SrcExpr.get()->getValueKind(), nullptr, CCK);
+    // Inserting a scalar cast here simplifies codegen when DestType is a
+    // vector: the cast can then reuse the regular vector splat lowering.
+    if (const VectorType *DVT = DestType->getAs<VectorType>())
+      SrcExpr = Self.ImpCastExprToType(
+          SrcExpr.get(), DVT->getElementType(),
+          Self.PrepareScalarCast(SrcExpr, DVT->getElementType()),
+          SrcExpr.get()->getValueKind(), nullptr, CCK);
     Kind = CK_HLSLAggregateSplatCast;
     return;
   }
diff --git a/clang/test/SemaHLSL/Language/AggregateSplatCasts.hlsl b/clang/test/SemaHLSL/Language/AggregateSplatCasts.hlsl
index 5a735de4ea346..e5a851c50caf8 100644
--- a/clang/test/SemaHLSL/Language/AggregateSplatCasts.hlsl
+++ b/clang/test/SemaHLSL/Language/AggregateSplatCasts.hlsl
@@ -3,6 +3,7 @@
 // splat from vec1 to vec
 // CHECK-LABEL: call1
 // CHECK: CStyleCastExpr {{.*}} 'int3':'vector<int, 3>' <HLSLAggregateSplatCast>
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int' lvalue <FloatingToIntegral> part_of_explicit_cast
 // CHECK-NEXT: ImplicitCastExpr {{.*}} 'float' lvalue <HLSLVectorTruncation> part_of_explicit_cast
 // CHECK-NEXT: DeclRefExpr {{.*}} 'float1':'vector<float, 1>' lvalue Var {{.*}} 'A' 'float1':'vector<float, 1>'
 export void call1() {

>From fdc868ea8455c8c833e4264d3b4cef040654f62e Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Thu, 13 Feb 2025 14:19:01 -0800
Subject: [PATCH 17/19] make sure you can't cast intangible types

---
 clang/lib/Sema/SemaHLSL.cpp                   |  3 ++
 .../SemaHLSL/Language/SplatCasts-errors.hlsl  | 31 ++++++++++++++-----
 2 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 7ea141ed44eca..a12abc8616c23 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2710,6 +2710,9 @@ bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
 // clarity of what types are supported
 bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
 
+  if (!SrcTy->isScalarType() || !DestTy->isScalarType())
+    return false;
+
   if (SemaRef.getASTContext().hasSameUnqualifiedType(SrcTy, DestTy))
     return true;
 
diff --git a/clang/test/SemaHLSL/Language/SplatCasts-errors.hlsl b/clang/test/SemaHLSL/Language/SplatCasts-errors.hlsl
index b0234c597eadf..662dae27e8200 100644
--- a/clang/test/SemaHLSL/Language/SplatCasts-errors.hlsl
+++ b/clang/test/SemaHLSL/Language/SplatCasts-errors.hlsl
@@ -1,17 +1,11 @@
-// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -verify
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -verify -verify-ignore-unexpected=note
 
 struct S {
-// expected-note at -1 {{candidate constructor (the implicit copy constructor) not viable: no known conversion from 'int' to 'const S' for 1st argument}}
-// expected-note at -2 {{candidate constructor (the implicit move constructor) not viable: no known conversion from 'int' to 'S' for 1st argument}}
-// expected-note at -3 {{candidate constructor (the implicit default constructor) not viable: requires 0 arguments, but 1 was provided}}
   int A : 8;
   int B;
 };
 
 struct R {
-// expected-note at -1 {{candidate constructor (the implicit copy constructor) not viable: no known conversion from 'int' to 'const R' for 1st argument}}
-// expected-note at -2 {{candidate constructor (the implicit move constructor) not viable: no known conversion from 'int' to 'R' for 1st argument}}
-// expected-note at -3 {{candidate constructor (the implicit default constructor) not viable: requires 0 arguments, but 1 was provided}}
   int A;
   union {
     float F;
@@ -30,3 +24,26 @@ export void cantCast2() {
   R r = (R)1;
   // expected-error at -1 {{no matching conversion for C-style cast from 'int' to 'R'}}
 }
+
+RWBuffer<float4> Buf;
+
+// Can't cast an intangible type
+export void cantCast3() {
+  Buf = (RWBuffer<float4>)1;
+  // expected-error at -1 {{no matching conversion for C-style cast from 'int' to 'RWBuffer<float4>' (aka 'RWBuffer<vector<float, 4>>')}}
+}
+
+export void cantCast4() {
+ RWBuffer<float4> B[2] = (RWBuffer<float4>[2])1;
+ // expected-error at -1 {{C-style cast from 'int' to 'RWBuffer<float4>[2]' (aka 'RWBuffer<vector<float, 4>>[2]') is not allowed}}
+}
+
+struct X {
+  int A;
+  RWBuffer<float4> Buf;
+};
+
+export void cantCast5() {
+  X x = (X)1;
+  // expected-error at -1 {{no matching conversion for C-style cast from 'int' to 'X'}}
+}
\ No newline at end of file

>From a56cfe973e6f9c41b591c51cd3022153287db882 Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Thu, 13 Feb 2025 14:26:42 -0800
Subject: [PATCH 18/19] rename file

---
 .../{SplatCasts-errors.hlsl => AggregateSplatCast-errors.hlsl}    | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename clang/test/SemaHLSL/Language/{SplatCasts-errors.hlsl => AggregateSplatCast-errors.hlsl} (100%)

diff --git a/clang/test/SemaHLSL/Language/SplatCasts-errors.hlsl b/clang/test/SemaHLSL/Language/AggregateSplatCast-errors.hlsl
similarity index 100%
rename from clang/test/SemaHLSL/Language/SplatCasts-errors.hlsl
rename to clang/test/SemaHLSL/Language/AggregateSplatCast-errors.hlsl

>From e4a03e7b1bdad3fb8473ff018fad76fa59078a8e Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Thu, 13 Feb 2025 15:25:12 -0800
Subject: [PATCH 19/19] revert code I did not intend to modify

---
 clang/lib/Sema/SemaExpr.cpp | 3 +--
 clang/lib/Sema/SemaType.cpp | 6 ++----
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index ef06d52284ff8..3cd4010740d19 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -20678,8 +20678,7 @@ ExprResult RebuildUnknownAnyExpr::VisitCallExpr(CallExpr *E) {
   const FunctionType *FnType = CalleeType->castAs<FunctionType>();
 
   // Verify that this is a legal result type of a function.
-  if ((!S.getLangOpts().HLSL && DestType->isArrayType()) ||
-      DestType->isFunctionType()) {
+  if (DestType->isArrayType() || DestType->isFunctionType()) {
     unsigned diagID = diag::err_func_returning_array_function;
     if (Kind == FK_BlockPointer)
       diagID = diag::err_block_returning_array_function;
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index df95677457c1f..fd1b73de912e5 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2527,7 +2527,7 @@ QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
 }
 
 bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) {
-  if ((!getLangOpts().HLSL && T->isArrayType()) || T->isFunctionType()) {
+  if (T->isArrayType() || T->isFunctionType()) {
     Diag(Loc, diag::err_func_returning_array_function)
       << T->isFunctionType() << T;
     return true;
@@ -4931,9 +4931,7 @@ static TypeSourceInfo *GetFullTypeForDeclarator(TypeProcessingState &state,
 
       // C99 6.7.5.3p1: The return type may not be a function or array type.
       // For conversion functions, we'll diagnose this particular error later.
-      if (!D.isInvalidType() &&
-          ((!S.getLangOpts().HLSL && T->isArrayType()) ||
-           T->isFunctionType()) &&
+      if (!D.isInvalidType() && (T->isArrayType() || T->isFunctionType()) &&
           (D.getName().getKind() !=
            UnqualifiedIdKind::IK_ConversionFunctionId)) {
         unsigned diagID = diag::err_func_returning_array_function;



More information about the cfe-commits mailing list