[clang] 4d2d0af - [HLSL] Implement HLSL Aggregate splatting (#118992)

via cfe-commits cfe-commits at lists.llvm.org
Fri Feb 14 09:25:28 PST 2025


Author: Sarah Spall
Date: 2025-02-14T09:25:24-08:00
New Revision: 4d2d0afceeb732a5238c2167ab7a6b88cc66d976

URL: https://github.com/llvm/llvm-project/commit/4d2d0afceeb732a5238c2167ab7a6b88cc66d976
DIFF: https://github.com/llvm/llvm-project/commit/4d2d0afceeb732a5238c2167ab7a6b88cc66d976.diff

LOG: [HLSL] Implement HLSL Aggregate splatting (#118992)

Implement HLSL aggregate splat casting, which handles splatting a scalar
(or a vec1) to arrays and structs, and splatting a vec1 to vectors.
Closes #100609 and Closes #100619
Depends on #118842

Added: 
    clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
    clang/test/SemaHLSL/Language/AggregateSplatCast-errors.hlsl
    clang/test/SemaHLSL/Language/AggregateSplatCasts.hlsl

Modified: 
    clang/include/clang/AST/OperationKinds.def
    clang/include/clang/Sema/SemaHLSL.h
    clang/lib/AST/Expr.cpp
    clang/lib/AST/ExprConstant.cpp
    clang/lib/CodeGen/CGExpr.cpp
    clang/lib/CodeGen/CGExprAgg.cpp
    clang/lib/CodeGen/CGExprComplex.cpp
    clang/lib/CodeGen/CGExprConstant.cpp
    clang/lib/CodeGen/CGExprScalar.cpp
    clang/lib/Edit/RewriteObjCFoundationAPI.cpp
    clang/lib/Sema/Sema.cpp
    clang/lib/Sema/SemaCast.cpp
    clang/lib/Sema/SemaHLSL.cpp
    clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
    clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl

Removed: 
    


################################################################################
diff  --git a/clang/include/clang/AST/OperationKinds.def b/clang/include/clang/AST/OperationKinds.def
index b3dc7c3d8dc77..790dd572a7c99 100644
--- a/clang/include/clang/AST/OperationKinds.def
+++ b/clang/include/clang/AST/OperationKinds.def
@@ -370,6 +370,9 @@ CAST_OPERATION(HLSLArrayRValue)
 // Aggregate by Value cast (HLSL only).
 CAST_OPERATION(HLSLElementwiseCast)
 
+// Splat cast for Aggregates (HLSL only).
+CAST_OPERATION(HLSLAggregateSplatCast)
+
 //===- Binary Operations  -------------------------------------------------===//
 // Operators listed in order of precedence.
 // Note that additions to this should also update the StmtVisitor class,

diff  --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 6e8ca2e4710de..c9266ea50e4bf 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -144,6 +144,7 @@ class SemaHLSL : public SemaBase {
   bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
   bool ContainsBitField(QualType BaseTy);
   bool CanPerformElementwiseCast(Expr *Src, QualType DestType);
+  bool CanPerformAggregateSplatCast(Expr *Src, QualType DestType);
   ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
 
   QualType getInoutParameterType(QualType Ty);

diff  --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index 6978fd837cbd0..460167c1b9a3d 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -1968,6 +1968,7 @@ bool CastExpr::CastConsistency() const {
   case CK_HLSLArrayRValue:
   case CK_HLSLVectorTruncation:
   case CK_HLSLElementwiseCast:
+  case CK_HLSLAggregateSplatCast:
   CheckNoBasePath:
     assert(path_empty() && "Cast kind should not have a base path!");
     break;

diff  --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 5c6ca4c9ee4de..043974fb41443 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -15025,6 +15025,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_FixedPointCast:
   case CK_IntegralToFixedPoint:
   case CK_MatrixCast:
+  case CK_HLSLAggregateSplatCast:
     llvm_unreachable("invalid cast kind for integral value");
 
   case CK_BitCast:
@@ -15903,6 +15904,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_MatrixCast:
   case CK_HLSLVectorTruncation:
   case CK_HLSLElementwiseCast:
+  case CK_HLSLAggregateSplatCast:
     llvm_unreachable("invalid cast kind for complex value");
 
   case CK_LValueToRValue:

diff  --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 2bbc0791c6587..0b0ffd2db853f 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -5339,6 +5339,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
   case CK_HLSLVectorTruncation:
   case CK_HLSLArrayRValue:
   case CK_HLSLElementwiseCast:
+  case CK_HLSLAggregateSplatCast:
     return EmitUnsupportedLValue(E, "unexpected cast lvalue");
 
   case CK_Dependent:

diff  --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index c574827ca0944..d25d0f2c2133c 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -498,6 +498,31 @@ static bool isTrivialFiller(Expr *E) {
   return false;
 }
 
+static void EmitHLSLAggregateSplatCast(CodeGenFunction &CGF, Address DestVal,
+                                       QualType DestTy, llvm::Value *SrcVal,
+                                       QualType SrcTy, SourceLocation Loc) {
+  // Flatten our destination
+  SmallVector<QualType> DestTypes; // Flattened type
+  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
+  // ^^ Flattened accesses to DestVal we want to store into
+  CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
+
+  assert(SrcTy->isScalarType() && "Invalid HLSL Aggregate splat cast.");
+  for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; ++I) {
+    llvm::Value *Cast =
+        CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[I], Loc);
+
+    // store back
+    llvm::Value *Idx = StoreGEPList[I].second;
+    if (Idx) {
+      llvm::Value *V =
+          CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
+      Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
+    }
+    CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
+  }
+}
+
 // emit a flat cast where the RHS is a scalar, including vector
 static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
                                    QualType DestTy, llvm::Value *SrcVal,
@@ -970,6 +995,19 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
   case CK_HLSLArrayRValue:
     Visit(E->getSubExpr());
     break;
+  case CK_HLSLAggregateSplatCast: {
+    Expr *Src = E->getSubExpr();
+    QualType SrcTy = Src->getType();
+    RValue RV = CGF.EmitAnyExpr(Src);
+    QualType DestTy = E->getType();
+    Address DestVal = Dest.getAddress();
+    SourceLocation Loc = E->getExprLoc();
+
+    assert(RV.isScalar() && "RHS of HLSL splat cast must be a scalar.");
+    llvm::Value *SrcVal = RV.getScalarVal();
+    EmitHLSLAggregateSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+    break;
+  }
   case CK_HLSLElementwiseCast: {
     Expr *Src = E->getSubExpr();
     QualType SrcTy = Src->getType();
@@ -1560,6 +1598,7 @@ static bool castPreservesZero(const CastExpr *CE) {
   case CK_AtomicToNonAtomic:
   case CK_HLSLVectorTruncation:
   case CK_HLSLElementwiseCast:
+  case CK_HLSLAggregateSplatCast:
     return true;
 
   case CK_BaseToDerivedMemberPointer:

diff  --git a/clang/lib/CodeGen/CGExprComplex.cpp b/clang/lib/CodeGen/CGExprComplex.cpp
index c2679ea92dc97..ff7c55be246cc 100644
--- a/clang/lib/CodeGen/CGExprComplex.cpp
+++ b/clang/lib/CodeGen/CGExprComplex.cpp
@@ -611,6 +611,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
   case CK_HLSLVectorTruncation:
   case CK_HLSLArrayRValue:
   case CK_HLSLElementwiseCast:
+  case CK_HLSLAggregateSplatCast:
     llvm_unreachable("invalid cast kind for complex value");
 
   case CK_FloatingRealToComplex:

diff  --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index ef11798869d3b..ee5874b26f534 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -1336,6 +1336,7 @@ class ConstExprEmitter
     case CK_HLSLVectorTruncation:
     case CK_HLSLArrayRValue:
     case CK_HLSLElementwiseCast:
+    case CK_HLSLAggregateSplatCast:
       return nullptr;
     }
     llvm_unreachable("Invalid CastKind");

diff  --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 80daed7e53951..30f01496ba221 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2643,6 +2643,11 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     return EmitScalarConversion(Visit(E), E->getType(), DestTy,
                                 CE->getExprLoc());
   }
+  // CK_HLSLAggregateSplatCast is only emitted as a scalar expression when
+  // splatting from a vec1 to a vector. Sema has already inserted casts that
+  // convert the source to a scalar and perform any necessary scalar cast,
+  // so this cast can be handled by the regular vector splat code.
+  case CK_HLSLAggregateSplatCast:
   case CK_VectorSplat: {
     llvm::Type *DstTy = ConvertType(DestTy);
     Value *Elt = Visit(const_cast<Expr *>(E));
@@ -2800,7 +2805,7 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     SourceLocation Loc = CE->getExprLoc();
     QualType SrcTy = E->getType();
 
-    assert(RV.isAggregate() && "Not a valid HLSL Flat Cast.");
+    assert(RV.isAggregate() && "Not a valid HLSL Elementwise Cast.");
     // RHS is an aggregate
     Address SrcVal = RV.getAggregateAddress();
     return EmitHLSLElementwiseCast(CGF, SrcVal, SrcTy, DestTy, Loc);

diff  --git a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
index 32f5ebb55155e..627a1d6fb3dd5 100644
--- a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
+++ b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
@@ -1086,6 +1086,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
 
     case CK_HLSLVectorTruncation:
     case CK_HLSLElementwiseCast:
+    case CK_HLSLAggregateSplatCast:
       llvm_unreachable("HLSL-specific cast in Objective-C?");
       break;
 

diff  --git a/clang/lib/Sema/Sema.cpp b/clang/lib/Sema/Sema.cpp
index 15c18f9a4525b..afd1d7a4e36c1 100644
--- a/clang/lib/Sema/Sema.cpp
+++ b/clang/lib/Sema/Sema.cpp
@@ -709,6 +709,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty,
     case CK_ToVoid:
     case CK_NonAtomicToAtomic:
     case CK_HLSLArrayRValue:
+    case CK_HLSLAggregateSplatCast:
       break;
     }
   }

diff  --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index 23be71ad8e2ae..8972957ded9f5 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -2776,9 +2776,9 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
   CheckedConversionKind CCK = FunctionalStyle
                                   ? CheckedConversionKind::FunctionalCast
                                   : CheckedConversionKind::CStyleCast;
-  // This case should not trigger on regular vector splat
-  // vector cast, vector truncation, or special hlsl splat cases
+
   QualType SrcTy = SrcExpr.get()->getType();
+  // This case should not trigger on regular vector cast, vector truncation
   if (Self.getLangOpts().HLSL &&
       Self.HLSL().CanPerformElementwiseCast(SrcExpr.get(), DestType)) {
     if (SrcTy->isConstantArrayType())
@@ -2789,6 +2789,28 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
     return;
   }
 
+  // This case should not trigger on regular vector splat
+  // If the relative order of this and the HLSLElementWise cast checks
+  // are changed, it might change which cast handles what in a few cases
+  if (Self.getLangOpts().HLSL &&
+      Self.HLSL().CanPerformAggregateSplatCast(SrcExpr.get(), DestType)) {
+    const VectorType *VT = SrcTy->getAs<VectorType>();
+    // change splat from vec1 case to splat from scalar
+    if (VT && VT->getNumElements() == 1)
+      SrcExpr = Self.ImpCastExprToType(
+          SrcExpr.get(), VT->getElementType(), CK_HLSLVectorTruncation,
+          SrcExpr.get()->getValueKind(), nullptr, CCK);
+    // Inserting a scalar cast here allows for a simplified codegen in
+    // the case the destTy is a vector
+    if (const VectorType *DVT = DestType->getAs<VectorType>())
+      SrcExpr = Self.ImpCastExprToType(
+          SrcExpr.get(), DVT->getElementType(),
+          Self.PrepareScalarCast(SrcExpr, DVT->getElementType()),
+          SrcExpr.get()->getValueKind(), nullptr, CCK);
+    Kind = CK_HLSLAggregateSplatCast;
+    return;
+  }
+
   if (ValueKind == VK_PRValue && !DestType->isRecordType() &&
       !isPlaceholder(BuiltinType::Overload)) {
     SrcExpr = Self.DefaultFunctionArrayLvalueConversion(SrcExpr.get());

diff  --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 4abd870ad6aaa..9a60054a6169e 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2717,6 +2717,9 @@ bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
 // clarity of what types are supported
 bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
 
+  if (!SrcTy->isScalarType() || !DestTy->isScalarType())
+    return false;
+
   if (SemaRef.getASTContext().hasSameUnqualifiedType(SrcTy, DestTy))
     return true;
 
@@ -2778,7 +2781,7 @@ bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
 }
 
 // Detect if a type contains a bitfield. Will be removed when
-// bitfield support is added to HLSLElementwiseCast
+// bitfield support is added to HLSLElementwiseCast and HLSLAggregateSplatCast
 bool SemaHLSL::ContainsBitField(QualType BaseTy) {
   llvm::SmallVector<QualType, 16> WorkList;
   WorkList.push_back(BaseTy);
@@ -2811,6 +2814,42 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) {
   return false;
 }
 
+// Can perform an HLSL Aggregate splat cast if the Dest is an aggregate and the
+// Src is a scalar or a vector of length 1
+// Or if Dest is a vector and Src is a vector of length 1
+bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
+
+  QualType SrcTy = Src->getType();
+  // Not a valid HLSL Aggregate Splat cast if Dest is a scalar or if this is
+  // going to be a vector splat from a scalar.
+  if ((SrcTy->isScalarType() && DestTy->isVectorType()) ||
+      DestTy->isScalarType())
+    return false;
+
+  const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
+
+  // Src isn't a scalar or a vector of length 1
+  if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
+    return false;
+
+  if (SrcVecTy)
+    SrcTy = SrcVecTy->getElementType();
+
+  if (ContainsBitField(DestTy))
+    return false;
+
+  llvm::SmallVector<QualType> DestTypes;
+  BuildFlattenedTypeList(DestTy, DestTypes);
+
+  for (unsigned I = 0, Size = DestTypes.size(); I < Size; ++I) {
+    if (DestTypes[I]->isUnionType())
+      return false;
+    if (!CanPerformScalarCast(SrcTy, DestTypes[I]))
+      return false;
+  }
+  return true;
+}
+
 // Can we perform an HLSL Elementwise cast?
 // TODO: update this code when matrices are added; see issue #88060
 bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {

diff  --git a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
index 3a983421358c7..1061dafbb2473 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
@@ -523,6 +523,7 @@ void ExprEngine::VisitCast(const CastExpr *CastE, const Expr *Ex,
       case CK_MatrixCast:
       case CK_VectorSplat:
       case CK_HLSLElementwiseCast:
+      case CK_HLSLAggregateSplatCast:
       case CK_HLSLVectorTruncation: {
         QualType resultType = CastE->getType();
         if (CastE->isGLValue())

diff  --git a/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
new file mode 100644
index 0000000000000..42b6abec1b3d8
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
@@ -0,0 +1,87 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
+
+// array splat
+// CHECK-LABEL: define void {{.*}}call4
+// CHECK: [[B:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
+// CHECK-NEXT: store i32 3, ptr [[G1]], align 4
+// CHECK-NEXT: store i32 3, ptr [[G2]], align 4
+export void call4() {
+  int B[2] = {1,2};
+  B = (int[2])3;
+}
+
+// splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call8
+// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK-NEXT: [[B:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4
+export void call8() {
+  int1 A = {1};
+  int B[2] = {1,2};
+  B = (int[2])A;
+}
+
+// vector splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call1
+// CHECK: [[B:%.*]] = alloca <1 x float>, align 4
+// CHECK-NEXT: [[A:%.*]] = alloca <4 x i32>, align 16
+// CHECK-NEXT: store <1 x float> splat (float 1.000000e+00), ptr [[B]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x float>, ptr [[B]], align 4
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i32 0
+// CHECK-NEXT: [[C:%.*]] = fptosi float [[VL]] to i32
+// CHECK-NEXT: [[SI:%.*]] = insertelement <4 x i32> poison, i32 [[C]], i64 0
+// CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT: store <4 x i32> [[S]], ptr [[A]], align 16
+export void call1() {
+  float1 B = {1.0};
+  int4 A = (int4)B;
+}
+
+struct S {
+  int X;
+  float Y;
+};
+
+// struct splats
+// CHECK-LABEL: define void {{.*}}call3
+// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
+export void call3() {
+  int1 A = {1};
+  S s = (S)A;
+}
+
+// struct splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call5
+// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
+export void call5() {
+  int1 A = {1};
+  S s = (S)A;
+}

diff  --git a/clang/test/SemaHLSL/Language/AggregateSplatCast-errors.hlsl b/clang/test/SemaHLSL/Language/AggregateSplatCast-errors.hlsl
new file mode 100644
index 0000000000000..662dae27e8200
--- /dev/null
+++ b/clang/test/SemaHLSL/Language/AggregateSplatCast-errors.hlsl
@@ -0,0 +1,49 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -verify -verify-ignore-unexpected=note
+
+struct S {
+  int A : 8;
+  int B;
+};
+
+struct R {
+  int A;
+  union {
+    float F;
+    int4 G;
+  };
+};
+
+// casting types which contain bitfields is not yet supported.
+export void cantCast() {
+  S s = (S)1;
+  // expected-error@-1 {{no matching conversion for C-style cast from 'int' to 'S'}}
+}
+
+// Can't cast a union
+export void cantCast2() {
+  R r = (R)1;
+  // expected-error@-1 {{no matching conversion for C-style cast from 'int' to 'R'}}
+}
+
+RWBuffer<float4> Buf;
+
+// Can't cast an intangible type
+export void cantCast3() {
+  Buf = (RWBuffer<float4>)1;
+  // expected-error@-1 {{no matching conversion for C-style cast from 'int' to 'RWBuffer<float4>' (aka 'RWBuffer<vector<float, 4>>')}}
+}
+
+export void cantCast4() {
+ RWBuffer<float4> B[2] = (RWBuffer<float4>[2])1;
+ // expected-error@-1 {{C-style cast from 'int' to 'RWBuffer<float4>[2]' (aka 'RWBuffer<vector<float, 4>>[2]') is not allowed}}
+}
+
+struct X {
+  int A;
+  RWBuffer<float4> Buf;
+};
+
+export void cantCast5() {
+  X x = (X)1;
+  // expected-error@-1 {{no matching conversion for C-style cast from 'int' to 'X'}}
+}
\ No newline at end of file

diff  --git a/clang/test/SemaHLSL/Language/AggregateSplatCasts.hlsl b/clang/test/SemaHLSL/Language/AggregateSplatCasts.hlsl
new file mode 100644
index 0000000000000..e5a851c50caf8
--- /dev/null
+++ b/clang/test/SemaHLSL/Language/AggregateSplatCasts.hlsl
@@ -0,0 +1,27 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -finclude-default-header -fnative-half-type %s -ast-dump | FileCheck %s
+
+// splat from vec1 to vec
+// CHECK-LABEL: call1
+// CHECK: CStyleCastExpr {{.*}} 'int3':'vector<int, 3>' <HLSLAggregateSplatCast>
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'int' lvalue <FloatingToIntegral> part_of_explicit_cast
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'float' lvalue <HLSLVectorTruncation> part_of_explicit_cast
+// CHECK-NEXT: DeclRefExpr {{.*}} 'float1':'vector<float, 1>' lvalue Var {{.*}} 'A' 'float1':'vector<float, 1>'
+export void call1() {
+  float1 A = {1.0};
+  int3 B = (int3)A;
+}
+
+struct S {
+  int A;
+  float B;
+  int C;
+  float D;
+};
+
+// splat from scalar to aggregate
+// CHECK-LABEL: call2
+// CHECK: CStyleCastExpr {{.*}} 'S' <HLSLAggregateSplatCast>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 5
+export void call2() {
+  S s = (S)5; 
+}

diff  --git a/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl b/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl
index c900c83a063a0..b7085bc69547b 100644
--- a/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl
+++ b/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl
@@ -27,3 +27,23 @@ export void cantCast3() {
   S s = (S)C;
   // expected-error@-1 {{no matching conversion for C-style cast from 'int2' (aka 'vector<int, 2>') to 'S'}}
 }
+
+struct R {
+// expected-note@-1 {{candidate constructor (the implicit copy constructor) not viable: no known conversion from 'int2' (aka 'vector<int, 2>') to 'const R' for 1st argument}}
+// expected-note@-2 {{candidate constructor (the implicit move constructor) not viable: no known conversion from 'int2' (aka 'vector<int, 2>') to 'R' for 1st argument}}
+// expected-note@-3 {{candidate constructor (the implicit default constructor) not viable: requires 0 arguments, but 1 was provided}}
+  int A;
+  union {
+    float F;
+    int4 G;
+  };
+};
+
+export void cantCast4() {
+  int2 A = {1,2};
+  R r = R(A);
+  // expected-error@-1 {{no matching conversion for functional-style cast from 'int2' (aka 'vector<int, 2>') to 'R'}}
+  R r2 = {1, 2};
+  int2 B = (int2)r2;
+  // expected-error@-1 {{cannot convert 'R' to 'int2' (aka 'vector<int, 2>') without a conversion operator}}
+}


        


More information about the cfe-commits mailing list