[clang] [HLSL] Implement HLSL splatting (PR #118992)

Sarah Spall via cfe-commits cfe-commits at lists.llvm.org
Sat Feb 8 09:07:54 PST 2025


https://github.com/spall updated https://github.com/llvm/llvm-project/pull/118992

>From e994824f3630ee8b224afceb6c14d980c9013112 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 05:14:17 +0000
Subject: [PATCH 1/9] splat cast wip

---
 clang/include/clang/AST/OperationKinds.def |  3 ++
 clang/include/clang/Sema/SemaHLSL.h        |  1 +
 clang/lib/CodeGen/CGExprAgg.cpp            | 42 ++++++++++++++++++++++
 clang/lib/CodeGen/CGExprScalar.cpp         | 16 +++++++++
 clang/lib/Sema/Sema.cpp                    |  1 +
 clang/lib/Sema/SemaCast.cpp                |  9 ++++-
 clang/lib/Sema/SemaHLSL.cpp                | 26 ++++++++++++++
 7 files changed, 97 insertions(+), 1 deletion(-)

diff --git a/clang/include/clang/AST/OperationKinds.def b/clang/include/clang/AST/OperationKinds.def
index b3dc7c3d8dc77e1..333fc7e1b18821e 100644
--- a/clang/include/clang/AST/OperationKinds.def
+++ b/clang/include/clang/AST/OperationKinds.def
@@ -370,6 +370,9 @@ CAST_OPERATION(HLSLArrayRValue)
 // Aggregate by Value cast (HLSL only).
 CAST_OPERATION(HLSLElementwiseCast)
 
+// Splat cast for Aggregates (HLSL only).
+CAST_OPERATION(HLSLSplatCast)
+
 //===- Binary Operations  -------------------------------------------------===//
 // Operators listed in order of precedence.
 // Note that additions to this should also update the StmtVisitor class,
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 6e8ca2e4710dec8..7508b149b0d81d0 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -144,6 +144,7 @@ class SemaHLSL : public SemaBase {
   bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
   bool ContainsBitField(QualType BaseTy);
   bool CanPerformElementwiseCast(Expr *Src, QualType DestType);
+  bool CanPerformSplat(Expr *Src, QualType DestType);
   ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
 
   QualType getInoutParameterType(QualType Ty);
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index c3f1cbed6b39f95..f26189bc4907cea 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -491,6 +491,33 @@ static bool isTrivialFiller(Expr *E) {
   return false;
 }
 
+static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
+			      QualType DestTy, llvm::Value *SrcVal,
+			      QualType SrcTy, SourceLocation Loc) {
+  // Flatten our destination
+  SmallVector<QualType> DestTypes; // Flattened type
+  SmallVector<llvm::Value *, 4> IdxList;
+  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
+  // ^^ Flattened accesses to DestVal we want to store into
+  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList,
+		       DestTypes);
+
+  if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
+    assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast.");
+
+    SrcTy = VT->getElementType();
+    SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0,
+					      "vec.load");
+  }
+  assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
+  for(unsigned i = 0; i < StoreGEPList.size(); i ++) {
+    llvm::Value *Cast = CGF.EmitScalarConversion(SrcVal, SrcTy,
+						 DestTypes[i],
+						 Loc);
+    CGF.PerformStore(StoreGEPList[i], Cast);
+  }
+}
+
 // emit a flat cast where the RHS is a scalar, including vector
 static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
                                    QualType DestTy, llvm::Value *SrcVal,
@@ -963,6 +990,21 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
   case CK_HLSLArrayRValue:
     Visit(E->getSubExpr());
     break;
+  case CK_HLSLSplatCast: {
+    Expr *Src = E->getSubExpr();
+    QualType SrcTy = Src->getType();
+    RValue RV = CGF.EmitAnyExpr(Src);
+    QualType DestTy = E->getType();
+    Address DestVal = Dest.getAddress();
+    SourceLocation Loc = E->getExprLoc();
+    
+    if (RV.isScalar()) {
+      llvm::Value *SrcVal = RV.getScalarVal();
+      EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+      break;
+    }
+    llvm_unreachable("RHS of HLSL splat cast must be a scalar or vector.");
+  }
   case CK_HLSLElementwiseCast: {
     Expr *Src = E->getSubExpr();
     QualType SrcTy = Src->getType();
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 80daed7e5395193..7dc2682bae42f2e 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2795,6 +2795,22 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
     return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
   }
+  case CK_HLSLSplatCast: {
+    assert(DestTy->isVectorType() && "Destination type must be a vector.");
+    auto *DestVecTy = DestTy->getAs<VectorType>();
+    QualType SrcTy = E->getType();
+    SourceLocation Loc = CE->getExprLoc();
+    Value *V = Visit(const_cast<Expr *>(E));
+    if (auto *VecTy = SrcTy->getAs<VectorType>()) {
+      assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
+      V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
+      SrcTy = VecTy->getElementType();
+    }
+    assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
+    Value *Cast = EmitScalarConversion(V, SrcTy,
+				       DestVecTy->getElementType(), Loc);
+    return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, "splat");
+  }
   case CK_HLSLElementwiseCast: {
     RValue RV = CGF.EmitAnyExpr(E);
     SourceLocation Loc = CE->getExprLoc();
diff --git a/clang/lib/Sema/Sema.cpp b/clang/lib/Sema/Sema.cpp
index 15c18f9a4525b22..9eeefbb3c002329 100644
--- a/clang/lib/Sema/Sema.cpp
+++ b/clang/lib/Sema/Sema.cpp
@@ -709,6 +709,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty,
     case CK_ToVoid:
     case CK_NonAtomicToAtomic:
     case CK_HLSLArrayRValue:
+    case CK_HLSLSplatCast:
       break;
     }
   }
diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index 23be71ad8e2aebc..56d8396b1e9d41a 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -2776,9 +2776,16 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
   CheckedConversionKind CCK = FunctionalStyle
                                   ? CheckedConversionKind::FunctionalCast
                                   : CheckedConversionKind::CStyleCast;
+
   // This case should not trigger on regular vector splat
-  // vector cast, vector truncation, or special hlsl splat cases
   QualType SrcTy = SrcExpr.get()->getType();
+  if (Self.getLangOpts().HLSL &&
+      Self.HLSL().CanPerformSplat(SrcExpr.get(), DestType)) {
+    Kind = CK_HLSLSplatCast;
+    return;
+  }
+
+  // This case should not trigger on regular vector cast, vector truncation
   if (Self.getLangOpts().HLSL &&
       Self.HLSL().CanPerformElementwiseCast(SrcExpr.get(), DestType)) {
     if (SrcTy->isConstantArrayType())
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index ec6b5b45de42bfa..7c9365787fd4fb5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2804,6 +2804,32 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) {
   return false;
 }
 
+// An HLSL splat cast is possible if Dest is an aggregate and Src is a scalar
+// or a vector of length 1, or if Dest is a vector and Src is a vector of
+// length 1. A scalar Dest is an ordinary conversion, never a splat.
+bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
+  QualType SrcTy = Src->getType();
+  if (DestTy->isScalarType() ||
+      (SrcTy->isScalarType() && DestTy->isVectorType()))
+    return false;
+
+  const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
+  if (!(SrcTy->isScalarType() || (SrcVecTy && SrcVecTy->getNumElements() == 1)))
+    return false;
+
+  if (SrcVecTy)
+    SrcTy = SrcVecTy->getElementType();
+
+  llvm::SmallVector<QualType> DestTypes;
+  BuildFlattenedTypeList(DestTy, DestTypes);
+
+  for(unsigned i = 0; i < DestTypes.size(); i ++) {
+    if (!CanPerformScalarCast(SrcTy, DestTypes[i]))
+      return false;
+  }
+  return true;
+}
+
 // Can we perform an HLSL Elementwise cast?
 // TODO: update this code when matrices are added; see issue #88060
 bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {

>From 24bea86dd7a2c39ca9f21480990236dc44df8cf3 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 05:19:00 +0000
Subject: [PATCH 2/9] make clang format happy

---
 clang/lib/CodeGen/CGExprAgg.cpp    | 19 ++++++++-----------
 clang/lib/CodeGen/CGExprScalar.cpp |  7 ++++---
 clang/lib/Sema/SemaHLSL.cpp        |  2 +-
 3 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index f26189bc4907cea..60beabf3a5fd0aa 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -492,28 +492,25 @@ static bool isTrivialFiller(Expr *E) {
 }
 
 static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
-			      QualType DestTy, llvm::Value *SrcVal,
-			      QualType SrcTy, SourceLocation Loc) {
+                              QualType DestTy, llvm::Value *SrcVal,
+                              QualType SrcTy, SourceLocation Loc) {
   // Flatten our destination
   SmallVector<QualType> DestTypes; // Flattened type
   SmallVector<llvm::Value *, 4> IdxList;
   SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
   // ^^ Flattened accesses to DestVal we want to store into
-  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList,
-		       DestTypes);
+  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes);
 
   if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
     assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast.");
 
     SrcTy = VT->getElementType();
-    SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0,
-					      "vec.load");
+    SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0, "vec.load");
   }
   assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
-  for(unsigned i = 0; i < StoreGEPList.size(); i ++) {
-    llvm::Value *Cast = CGF.EmitScalarConversion(SrcVal, SrcTy,
-						 DestTypes[i],
-						 Loc);
+  for (unsigned i = 0; i < StoreGEPList.size(); i++) {
+    llvm::Value *Cast =
+        CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[i], Loc);
     CGF.PerformStore(StoreGEPList[i], Cast);
   }
 }
@@ -997,7 +994,7 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
     QualType DestTy = E->getType();
     Address DestVal = Dest.getAddress();
     SourceLocation Loc = E->getExprLoc();
-    
+
     if (RV.isScalar()) {
       llvm::Value *SrcVal = RV.getScalarVal();
       EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 7dc2682bae42f2e..4a20b693b101fae 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2807,9 +2807,10 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
       SrcTy = VecTy->getElementType();
     }
     assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
-    Value *Cast = EmitScalarConversion(V, SrcTy,
-				       DestVecTy->getElementType(), Loc);
-    return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, "splat");
+    Value *Cast =
+        EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc);
+    return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast,
+                                     "splat");
   }
   case CK_HLSLElementwiseCast: {
     RValue RV = CGF.EmitAnyExpr(E);
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 7c9365787fd4fb5..024f778f8ffef5b 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2823,7 +2823,7 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
   llvm::SmallVector<QualType> DestTypes;
   BuildFlattenedTypeList(DestTy, DestTypes);
 
-  for(unsigned i = 0; i < DestTypes.size(); i ++) {
+  for (unsigned i = 0; i < DestTypes.size(); i++) {
     if (!CanPerformScalarCast(SrcTy, DestTypes[i]))
       return false;
   }

>From 3575617d436f04eac4faadc17ead8bfe561e7e7c Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 05:59:12 +0000
Subject: [PATCH 3/9] codegen test

---
 .../CodeGenHLSL/BasicFeatures/SplatCast.hlsl  | 87 +++++++++++++++++++
 1 file changed, 87 insertions(+)
 create mode 100644 clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl

diff --git a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
new file mode 100644
index 000000000000000..05359c1bce0ba35
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
@@ -0,0 +1,87 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
+
+// array splat
+// CHECK-LABEL: define void {{.*}}call4
+// CHECK: [[B:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: store i32 3, ptr [[G1]], align 4
+// CHECK-NEXT: store i32 3, ptr [[G2]], align 4
+export void call4() {
+  int B[2] = {1,2};
+  B = (int[2])3;
+}
+
+// splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call8
+// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK-NEXT: [[B:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4
+export void call8() {
+  int1 A = {1};
+  int B[2] = {1,2};
+  B = (int[2])A;
+}
+
+// vector splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call1
+// CHECK: [[B:%.*]] = alloca <1 x float>, align 4
+// CHECK-NEXT: [[A:%.*]] = alloca <4 x i32>, align 16
+// CHECK-NEXT: store <1 x float> splat (float 1.000000e+00), ptr [[B]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x float>, ptr [[B]], align 4
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i64 0
+// CHECK-NEXT: [[C:%.*]] = fptosi float [[VL]] to i32
+// CHECK-NEXT: [[SI:%.*]] = insertelement <4 x i32> poison, i32 [[C]], i64 0
+// CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT: store <4 x i32> [[S]], ptr [[A]], align 16
+export void call1() {
+  float1 B = {1.0};
+  int4 A = (int4)B;
+}
+
+struct S {
+  int X;
+  float Y;
+};
+
+// struct splats?
+// CHECK-LABEL: define void {{.*}}call3
+// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
+export void call3() {
+  int1 A = {1};
+  S s = (S)A;
+}
+
+// struct splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call5
+// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
+export void call5() {
+  int1 A = {1};
+  S s = (S)A;
+}

>From 288b8dac1c6fa4429c92c566a69da593c2ebb97c Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 17:38:58 +0000
Subject: [PATCH 4/9] Try to handle Cast in all the places it needs to be
 handled

---
 clang/lib/AST/Expr.cpp                        | 1 +
 clang/lib/AST/ExprConstant.cpp                | 2 ++
 clang/lib/CodeGen/CGExprAgg.cpp               | 1 +
 clang/lib/CodeGen/CGExprComplex.cpp           | 1 +
 clang/lib/CodeGen/CGExprConstant.cpp          | 1 +
 clang/lib/Edit/RewriteObjCFoundationAPI.cpp   | 1 +
 clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp | 1 +
 7 files changed, 8 insertions(+)

diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index c22aa66ba2cfb3d..bbb475fbb30f269 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -1957,6 +1957,7 @@ bool CastExpr::CastConsistency() const {
   case CK_HLSLArrayRValue:
   case CK_HLSLVectorTruncation:
   case CK_HLSLElementwiseCast:
+  case CK_HLSLSplatCast:
   CheckNoBasePath:
     assert(path_empty() && "Cast kind should not have a base path!");
     break;
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 192b679b4c99596..ddc2d008839007e 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -15029,6 +15029,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_FixedPointCast:
   case CK_IntegralToFixedPoint:
   case CK_MatrixCast:
+  case CK_HLSLSplatCast:
     llvm_unreachable("invalid cast kind for integral value");
 
   case CK_BitCast:
@@ -15907,6 +15908,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_MatrixCast:
   case CK_HLSLVectorTruncation:
   case CK_HLSLElementwiseCast:
+  case CK_HLSLSplatCast:
     llvm_unreachable("invalid cast kind for complex value");
 
   case CK_LValueToRValue:
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 60beabf3a5fd0aa..3584280e2fb9e44 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -1592,6 +1592,7 @@ static bool castPreservesZero(const CastExpr *CE) {
   case CK_AtomicToNonAtomic:
   case CK_HLSLVectorTruncation:
   case CK_HLSLElementwiseCast:
+    // TODO is this true for CK_HLSLSplatCast
     return true;
 
   case CK_BaseToDerivedMemberPointer:
diff --git a/clang/lib/CodeGen/CGExprComplex.cpp b/clang/lib/CodeGen/CGExprComplex.cpp
index c2679ea92dc9728..3832b9b598b24e9 100644
--- a/clang/lib/CodeGen/CGExprComplex.cpp
+++ b/clang/lib/CodeGen/CGExprComplex.cpp
@@ -611,6 +611,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
   case CK_HLSLVectorTruncation:
   case CK_HLSLArrayRValue:
   case CK_HLSLElementwiseCast:
+  case CK_HLSLSplatCast:
     llvm_unreachable("invalid cast kind for complex value");
 
   case CK_FloatingRealToComplex:
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index ef11798869d3b13..b8ce83803b65fde 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -1336,6 +1336,7 @@ class ConstExprEmitter
     case CK_HLSLVectorTruncation:
     case CK_HLSLArrayRValue:
     case CK_HLSLElementwiseCast:
+    case CK_HLSLSplatCast:
       return nullptr;
     }
     llvm_unreachable("Invalid CastKind");
diff --git a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
index 32f5ebb55155ed1..10d3f62fcd0a416 100644
--- a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
+++ b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
@@ -1086,6 +1086,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
 
     case CK_HLSLVectorTruncation:
     case CK_HLSLElementwiseCast:
+    case CK_HLSLSplatCast:
       llvm_unreachable("HLSL-specific cast in Objective-C?");
       break;
 
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
index 3a983421358c7f4..d75583f68eb6b7b 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
@@ -523,6 +523,7 @@ void ExprEngine::VisitCast(const CastExpr *CastE, const Expr *Ex,
       case CK_MatrixCast:
       case CK_VectorSplat:
       case CK_HLSLElementwiseCast:
+      case CK_HLSLSplatCast:
       case CK_HLSLVectorTruncation: {
         QualType resultType = CastE->getType();
         if (CastE->isGLValue())

>From 0650840642960d950d64e234e9641e34096a6c55 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Wed, 11 Dec 2024 20:54:39 +0000
Subject: [PATCH 5/9] get code compiling after rebase

---
 clang/lib/CodeGen/CGExprAgg.cpp | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 3584280e2fb9e44..3330cd03628f75e 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -496,10 +496,9 @@ static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
                               QualType SrcTy, SourceLocation Loc) {
   // Flatten our destination
   SmallVector<QualType> DestTypes; // Flattened type
-  SmallVector<llvm::Value *, 4> IdxList;
   SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
   // ^^ Flattened accesses to DestVal we want to store into
-  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes);
+  CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
 
   if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
     assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast.");
@@ -511,7 +510,15 @@ static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
   for (unsigned i = 0; i < StoreGEPList.size(); i++) {
     llvm::Value *Cast =
         CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[i], Loc);
-    CGF.PerformStore(StoreGEPList[i], Cast);
+
+    // store back
+    llvm::Value *Idx = StoreGEPList[i].second;
+    if (Idx) {
+      llvm::Value *V =
+          CGF.Builder.CreateLoad(StoreGEPList[i].first, "load.for.insert");
+      Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
+    }
+    CGF.Builder.CreateStore(Cast, StoreGEPList[i].first);
   }
 }
 

>From f924b13ada0c3344f3cc4f87a859f0ecd16705cb Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 12 Dec 2024 00:04:29 +0000
Subject: [PATCH 6/9] Self review

---
 clang/lib/CodeGen/CGExprScalar.cpp           | 15 +++++++-----
 clang/lib/Sema/SemaHLSL.cpp                  |  7 +++---
 clang/test/SemaHLSL/Language/SplatCasts.hlsl | 25 ++++++++++++++++++++
 3 files changed, 38 insertions(+), 9 deletions(-)
 create mode 100644 clang/test/SemaHLSL/Language/SplatCasts.hlsl

diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 4a20b693b101fae..85c0265ea14b611 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2796,17 +2796,20 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
   }
   case CK_HLSLSplatCast: {
+    // This code should only handle splatting from vectors of length 1.
     assert(DestTy->isVectorType() && "Destination type must be a vector.");
     auto *DestVecTy = DestTy->getAs<VectorType>();
     QualType SrcTy = E->getType();
     SourceLocation Loc = CE->getExprLoc();
     Value *V = Visit(const_cast<Expr *>(E));
-    if (auto *VecTy = SrcTy->getAs<VectorType>()) {
-      assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
-      V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
-      SrcTy = VecTy->getElementType();
-    }
-    assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
+    assert(SrcTy->isVectorType() && "Invalid HLSL splat cast.");
+
+    auto *VecTy = SrcTy->getAs<VectorType>();
+    assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
+
+    V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
+    SrcTy = VecTy->getElementType();
+
     Value *Cast =
         EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc);
     return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast,
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 024f778f8ffef5b..432a42016789ec2 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2814,12 +2814,13 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
     return false;
 
   const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
-  if (!(SrcTy->isScalarType() || (SrcVecTy && SrcVecTy->getNumElements() == 1)))
-    return false;
-
   if (SrcVecTy)
     SrcTy = SrcVecTy->getElementType();
 
+  // Src isn't a scalar or a vector of length 1
+  if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
+    return false;
+
   llvm::SmallVector<QualType> DestTypes;
   BuildFlattenedTypeList(DestTy, DestTypes);
 
diff --git a/clang/test/SemaHLSL/Language/SplatCasts.hlsl b/clang/test/SemaHLSL/Language/SplatCasts.hlsl
new file mode 100644
index 000000000000000..593a8e67fd4a3b8
--- /dev/null
+++ b/clang/test/SemaHLSL/Language/SplatCasts.hlsl
@@ -0,0 +1,25 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -finclude-default-header -fnative-half-type %s -ast-dump | FileCheck %s
+
+// splat from vec1 to vec
+// CHECK-LABEL: call1
+// CHECK: CStyleCastExpr {{.*}} 'int3':'vector<int, 3>' <HLSLSplatCast>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'float1':'vector<float, 1>' lvalue Var {{.*}} 'A' 'float1':'vector<float, 1>'
+export void call1() {
+  float1 A = {1.0};
+  int3 B = (int3)A;
+}
+
+struct S {
+  int A;
+  float B;
+  int C;
+  float D;
+};
+
+// splat from scalar to aggregate
+// CHECK-LABEL: call2
+// CHECK: CStyleCastExpr {{.*}} 'S' <HLSLSplatCast>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 5
+export void call2() {
+  S s = (S)5;
+}

>From 89ceeb7d6b445f10fa6b7deb8c10267cd292da7b Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 12 Dec 2024 05:59:55 +0000
Subject: [PATCH 7/9] move code back that broke tests

---
 clang/lib/Sema/SemaHLSL.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 432a42016789ec2..76ca24b10c60a16 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2814,13 +2814,14 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
     return false;
 
   const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
-  if (SrcVecTy)
-    SrcTy = SrcVecTy->getElementType();
 
   // Src isn't a scalar or a vector of length 1
   if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
     return false;
 
+  if (SrcVecTy)
+    SrcTy = SrcVecTy->getElementType();
+
   llvm::SmallVector<QualType> DestTypes;
   BuildFlattenedTypeList(DestTy, DestTypes);
 

>From 7f5b3e4f39f2a4cf2d42e5281e70d900878c1a3b Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 12 Dec 2024 06:08:46 +0000
Subject: [PATCH 8/9] fix tests

---
 .../CodeGenHLSL/BasicFeatures/SplatCast.hlsl     | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
index 05359c1bce0ba35..2de68479179dd4c 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
@@ -4,8 +4,8 @@
 // CHECK-LABEL: define void {{.*}}call4
 // CHECK: [[B:%.*]] = alloca [2 x i32], align 4
 // CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
-// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
-// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
 // CHECK-NEXT: store i32 3, ptr [[G1]], align 4
 // CHECK-NEXT: store i32 3, ptr [[G2]], align 4
 export void call4() {
@@ -20,8 +20,8 @@ export void call4() {
 // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
 // CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
 // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
-// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
-// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
 // CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
 // CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4
@@ -58,8 +58,8 @@ struct S {
 // CHECK: [[s:%.*]] = alloca %struct.S, align 4
 // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
 // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
-// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
-// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
 // CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
 // CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
@@ -75,8 +75,8 @@ export void call3() {
 // CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
 // CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
 // CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
-// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
-// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
 // CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
 // CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
 // CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float

>From 844ba82eb5dcfbd0105db2d4943266fa8d009c17 Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Sat, 8 Feb 2025 09:07:05 -0800
Subject: [PATCH 9/9] add cast to cases

---
 clang/lib/CodeGen/CGExpr.cpp    | 1 +
 clang/lib/CodeGen/CGExprAgg.cpp | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 2bbc0791c65876f..545d8b11a6a47a9 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -5339,6 +5339,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
   case CK_HLSLVectorTruncation:
   case CK_HLSLArrayRValue:
   case CK_HLSLElementwiseCast:
+  case CK_HLSLSplatCast:
     return EmitUnsupportedLValue(E, "unexpected cast lvalue");
 
   case CK_Dependent:
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 3330cd03628f75e..b7fe62687b074a0 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -1599,7 +1599,7 @@ static bool castPreservesZero(const CastExpr *CE) {
   case CK_AtomicToNonAtomic:
   case CK_HLSLVectorTruncation:
   case CK_HLSLElementwiseCast:
-    // TODO is this true for CK_HLSLSplatCast
+  case CK_HLSLSplatCast:
     return true;
 
   case CK_BaseToDerivedMemberPointer:



More information about the cfe-commits mailing list