[clang] [HLSL] Implement HLSL splatting (PR #118992)
Sarah Spall via cfe-commits
cfe-commits at lists.llvm.org
Sat Feb 8 09:07:54 PST 2025
https://github.com/spall updated https://github.com/llvm/llvm-project/pull/118992
>From e994824f3630ee8b224afceb6c14d980c9013112 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 05:14:17 +0000
Subject: [PATCH 1/9] splat cast wip
---
clang/include/clang/AST/OperationKinds.def | 3 ++
clang/include/clang/Sema/SemaHLSL.h | 1 +
clang/lib/CodeGen/CGExprAgg.cpp | 42 ++++++++++++++++++++++
clang/lib/CodeGen/CGExprScalar.cpp | 16 +++++++++
clang/lib/Sema/Sema.cpp | 1 +
clang/lib/Sema/SemaCast.cpp | 9 ++++-
clang/lib/Sema/SemaHLSL.cpp | 26 ++++++++++++++
7 files changed, 97 insertions(+), 1 deletion(-)
diff --git a/clang/include/clang/AST/OperationKinds.def b/clang/include/clang/AST/OperationKinds.def
index b3dc7c3d8dc77e1..333fc7e1b18821e 100644
--- a/clang/include/clang/AST/OperationKinds.def
+++ b/clang/include/clang/AST/OperationKinds.def
@@ -370,6 +370,9 @@ CAST_OPERATION(HLSLArrayRValue)
// Aggregate by Value cast (HLSL only).
CAST_OPERATION(HLSLElementwiseCast)
+// Splat cast for Aggregates (HLSL only).
+CAST_OPERATION(HLSLSplatCast)
+
//===- Binary Operations -------------------------------------------------===//
// Operators listed in order of precedence.
// Note that additions to this should also update the StmtVisitor class,
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 6e8ca2e4710dec8..7508b149b0d81d0 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -144,6 +144,7 @@ class SemaHLSL : public SemaBase {
bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
bool ContainsBitField(QualType BaseTy);
bool CanPerformElementwiseCast(Expr *Src, QualType DestType);
+ bool CanPerformSplat(Expr *Src, QualType DestType);
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
QualType getInoutParameterType(QualType Ty);
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index c3f1cbed6b39f95..f26189bc4907cea 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -491,6 +491,33 @@ static bool isTrivialFiller(Expr *E) {
return false;
}
+static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
+ QualType DestTy, llvm::Value *SrcVal,
+ QualType SrcTy, SourceLocation Loc) {
+ // Flatten our destination
+ SmallVector<QualType> DestTypes; // Flattened type
+ SmallVector<llvm::Value *, 4> IdxList;
+ SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
+ // ^^ Flattened accesses to DestVal we want to store into
+ CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList,
+ DestTypes);
+
+ if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
+ assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast.");
+
+ SrcTy = VT->getElementType();
+ SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0,
+ "vec.load");
+ }
+ assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
+ for(unsigned i = 0; i < StoreGEPList.size(); i ++) {
+ llvm::Value *Cast = CGF.EmitScalarConversion(SrcVal, SrcTy,
+ DestTypes[i],
+ Loc);
+ CGF.PerformStore(StoreGEPList[i], Cast);
+ }
+}
+
// emit a flat cast where the RHS is a scalar, including vector
static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
QualType DestTy, llvm::Value *SrcVal,
@@ -963,6 +990,21 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
case CK_HLSLArrayRValue:
Visit(E->getSubExpr());
break;
+ case CK_HLSLSplatCast: {
+ Expr *Src = E->getSubExpr();
+ QualType SrcTy = Src->getType();
+ RValue RV = CGF.EmitAnyExpr(Src);
+ QualType DestTy = E->getType();
+ Address DestVal = Dest.getAddress();
+ SourceLocation Loc = E->getExprLoc();
+
+ if (RV.isScalar()) {
+ llvm::Value *SrcVal = RV.getScalarVal();
+ EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+ break;
+ }
+ llvm_unreachable("RHS of HLSL splat cast must be a scalar or vector.");
+ }
case CK_HLSLElementwiseCast: {
Expr *Src = E->getSubExpr();
QualType SrcTy = Src->getType();
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 80daed7e5395193..7dc2682bae42f2e 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2795,6 +2795,22 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
}
+ case CK_HLSLSplatCast: {
+ assert(DestTy->isVectorType() && "Destination type must be a vector.");
+ auto *DestVecTy = DestTy->getAs<VectorType>();
+ QualType SrcTy = E->getType();
+ SourceLocation Loc = CE->getExprLoc();
+ Value *V = Visit(const_cast<Expr *>(E));
+ if (auto *VecTy = SrcTy->getAs<VectorType>()) {
+ assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
+ V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
+ SrcTy = VecTy->getElementType();
+ }
+ assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
+ Value *Cast = EmitScalarConversion(V, SrcTy,
+ DestVecTy->getElementType(), Loc);
+ return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, "splat");
+ }
case CK_HLSLElementwiseCast: {
RValue RV = CGF.EmitAnyExpr(E);
SourceLocation Loc = CE->getExprLoc();
diff --git a/clang/lib/Sema/Sema.cpp b/clang/lib/Sema/Sema.cpp
index 15c18f9a4525b22..9eeefbb3c002329 100644
--- a/clang/lib/Sema/Sema.cpp
+++ b/clang/lib/Sema/Sema.cpp
@@ -709,6 +709,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty,
case CK_ToVoid:
case CK_NonAtomicToAtomic:
case CK_HLSLArrayRValue:
+ case CK_HLSLSplatCast:
break;
}
}
diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index 23be71ad8e2aebc..56d8396b1e9d41a 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -2776,9 +2776,16 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
CheckedConversionKind CCK = FunctionalStyle
? CheckedConversionKind::FunctionalCast
: CheckedConversionKind::CStyleCast;
+
// This case should not trigger on regular vector splat
- // vector cast, vector truncation, or special hlsl splat cases
QualType SrcTy = SrcExpr.get()->getType();
+ if (Self.getLangOpts().HLSL &&
+ Self.HLSL().CanPerformSplat(SrcExpr.get(), DestType)) {
+ Kind = CK_HLSLSplatCast;
+ return;
+ }
+
+ // This case should not trigger on regular vector cast, vector truncation
if (Self.getLangOpts().HLSL &&
Self.HLSL().CanPerformElementwiseCast(SrcExpr.get(), DestType)) {
if (SrcTy->isConstantArrayType())
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index ec6b5b45de42bfa..7c9365787fd4fb5 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2804,6 +2804,32 @@ bool SemaHLSL::ContainsBitField(QualType BaseTy) {
return false;
}
+// Can perform an HLSL splat cast if the Dest is an aggregate and the
+// Src is a scalar or a vector of length 1,
+// or if the Dest is a vector and the Src is a vector of length 1.
+bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
+
+ QualType SrcTy = Src->getType();
+ if (SrcTy->isScalarType() && DestTy->isVectorType())
+ return false;
+
+ const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
+ if (!(SrcTy->isScalarType() || (SrcVecTy && SrcVecTy->getNumElements() == 1)))
+ return false;
+
+ if (SrcVecTy)
+ SrcTy = SrcVecTy->getElementType();
+
+ llvm::SmallVector<QualType> DestTypes;
+ BuildFlattenedTypeList(DestTy, DestTypes);
+
+ for(unsigned i = 0; i < DestTypes.size(); i ++) {
+ if (!CanPerformScalarCast(SrcTy, DestTypes[i]))
+ return false;
+ }
+ return true;
+}
+
// Can we perform an HLSL Elementwise cast?
// TODO: update this code when matrices are added; see issue #88060
bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
>From 24bea86dd7a2c39ca9f21480990236dc44df8cf3 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 05:19:00 +0000
Subject: [PATCH 2/9] make clang format happy
---
clang/lib/CodeGen/CGExprAgg.cpp | 19 ++++++++-----------
clang/lib/CodeGen/CGExprScalar.cpp | 7 ++++---
clang/lib/Sema/SemaHLSL.cpp | 2 +-
3 files changed, 13 insertions(+), 15 deletions(-)
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index f26189bc4907cea..60beabf3a5fd0aa 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -492,28 +492,25 @@ static bool isTrivialFiller(Expr *E) {
}
static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
- QualType DestTy, llvm::Value *SrcVal,
- QualType SrcTy, SourceLocation Loc) {
+ QualType DestTy, llvm::Value *SrcVal,
+ QualType SrcTy, SourceLocation Loc) {
// Flatten our destination
SmallVector<QualType> DestTypes; // Flattened type
SmallVector<llvm::Value *, 4> IdxList;
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
// ^^ Flattened accesses to DestVal we want to store into
- CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList,
- DestTypes);
+ CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes);
if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast.");
SrcTy = VT->getElementType();
- SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0,
- "vec.load");
+ SrcVal = CGF.Builder.CreateExtractElement(SrcVal, (uint64_t)0, "vec.load");
}
assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
- for(unsigned i = 0; i < StoreGEPList.size(); i ++) {
- llvm::Value *Cast = CGF.EmitScalarConversion(SrcVal, SrcTy,
- DestTypes[i],
- Loc);
+ for (unsigned i = 0; i < StoreGEPList.size(); i++) {
+ llvm::Value *Cast =
+ CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[i], Loc);
CGF.PerformStore(StoreGEPList[i], Cast);
}
}
@@ -997,7 +994,7 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
QualType DestTy = E->getType();
Address DestVal = Dest.getAddress();
SourceLocation Loc = E->getExprLoc();
-
+
if (RV.isScalar()) {
llvm::Value *SrcVal = RV.getScalarVal();
EmitHLSLSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 7dc2682bae42f2e..4a20b693b101fae 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2807,9 +2807,10 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
SrcTy = VecTy->getElementType();
}
assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
- Value *Cast = EmitScalarConversion(V, SrcTy,
- DestVecTy->getElementType(), Loc);
- return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast, "splat");
+ Value *Cast =
+ EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc);
+ return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast,
+ "splat");
}
case CK_HLSLElementwiseCast: {
RValue RV = CGF.EmitAnyExpr(E);
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 7c9365787fd4fb5..024f778f8ffef5b 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2823,7 +2823,7 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
llvm::SmallVector<QualType> DestTypes;
BuildFlattenedTypeList(DestTy, DestTypes);
- for(unsigned i = 0; i < DestTypes.size(); i ++) {
+ for (unsigned i = 0; i < DestTypes.size(); i++) {
if (!CanPerformScalarCast(SrcTy, DestTypes[i]))
return false;
}
>From 3575617d436f04eac4faadc17ead8bfe561e7e7c Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 05:59:12 +0000
Subject: [PATCH 3/9] codegen test
---
.../CodeGenHLSL/BasicFeatures/SplatCast.hlsl | 87 +++++++++++++++++++
1 file changed, 87 insertions(+)
create mode 100644 clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
new file mode 100644
index 000000000000000..05359c1bce0ba35
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
@@ -0,0 +1,87 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
+
+// array splat
+// CHECK-LABEL: define void {{.*}}call4
+// CHECK: [[B:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: store i32 3, ptr [[G1]], align 4
+// CHECK-NEXT: store i32 3, ptr [[G2]], align 4
+export void call4() {
+ int B[2] = {1,2};
+ B = (int[2])3;
+}
+
+// splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call8
+// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK-NEXT: [[B:%.*]] = alloca [2 x i32], align 4
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4
+export void call8() {
+ int1 A = {1};
+ int B[2] = {1,2};
+ B = (int[2])A;
+}
+
+// vector splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call1
+// CHECK: [[B:%.*]] = alloca <1 x float>, align 4
+// CHECK-NEXT: [[A:%.*]] = alloca <4 x i32>, align 16
+// CHECK-NEXT: store <1 x float> splat (float 1.000000e+00), ptr [[B]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x float>, ptr [[B]], align 4
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x float> [[L]], i64 0
+// CHECK-NEXT: [[C:%.*]] = fptosi float [[VL]] to i32
+// CHECK-NEXT: [[SI:%.*]] = insertelement <4 x i32> poison, i32 [[C]], i64 0
+// CHECK-NEXT: [[S:%.*]] = shufflevector <4 x i32> [[SI]], <4 x i32> poison, <4 x i32> zeroinitializer
+// CHECK-NEXT: store <4 x i32> [[S]], ptr [[A]], align 16
+export void call1() {
+ float1 B = {1.0};
+ int4 A = (int4)B;
+}
+
+struct S {
+ int X;
+ float Y;
+};
+
+// struct splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call3
+// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
+export void call3() {
+ int1 A = {1};
+ S s = (S)A;
+}
+
+// struct splat from vector of length 1
+// CHECK-LABEL: define void {{.*}}call5
+// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
+// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
+// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
+// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
+export void call5() {
+ int1 A = {1};
+ S s = (S)A;
+}
>From 288b8dac1c6fa4429c92c566a69da593c2ebb97c Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Fri, 6 Dec 2024 17:38:58 +0000
Subject: [PATCH 4/9] Try to handle Cast in all the places it needs to be
handled
---
clang/lib/AST/Expr.cpp | 1 +
clang/lib/AST/ExprConstant.cpp | 2 ++
clang/lib/CodeGen/CGExprAgg.cpp | 1 +
clang/lib/CodeGen/CGExprComplex.cpp | 1 +
clang/lib/CodeGen/CGExprConstant.cpp | 1 +
clang/lib/Edit/RewriteObjCFoundationAPI.cpp | 1 +
clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp | 1 +
7 files changed, 8 insertions(+)
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index c22aa66ba2cfb3d..bbb475fbb30f269 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -1957,6 +1957,7 @@ bool CastExpr::CastConsistency() const {
case CK_HLSLArrayRValue:
case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
+ case CK_HLSLSplatCast:
CheckNoBasePath:
assert(path_empty() && "Cast kind should not have a base path!");
break;
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 192b679b4c99596..ddc2d008839007e 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -15029,6 +15029,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_FixedPointCast:
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
+ case CK_HLSLSplatCast:
llvm_unreachable("invalid cast kind for integral value");
case CK_BitCast:
@@ -15907,6 +15908,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
+ case CK_HLSLSplatCast:
llvm_unreachable("invalid cast kind for complex value");
case CK_LValueToRValue:
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 60beabf3a5fd0aa..3584280e2fb9e44 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -1592,6 +1592,7 @@ static bool castPreservesZero(const CastExpr *CE) {
case CK_AtomicToNonAtomic:
case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
+ // TODO is this true for CK_HLSLSplatCast
return true;
case CK_BaseToDerivedMemberPointer:
diff --git a/clang/lib/CodeGen/CGExprComplex.cpp b/clang/lib/CodeGen/CGExprComplex.cpp
index c2679ea92dc9728..3832b9b598b24e9 100644
--- a/clang/lib/CodeGen/CGExprComplex.cpp
+++ b/clang/lib/CodeGen/CGExprComplex.cpp
@@ -611,6 +611,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
+ case CK_HLSLSplatCast:
llvm_unreachable("invalid cast kind for complex value");
case CK_FloatingRealToComplex:
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index ef11798869d3b13..b8ce83803b65fde 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -1336,6 +1336,7 @@ class ConstExprEmitter
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
+ case CK_HLSLSplatCast:
return nullptr;
}
llvm_unreachable("Invalid CastKind");
diff --git a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
index 32f5ebb55155ed1..10d3f62fcd0a416 100644
--- a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
+++ b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
@@ -1086,6 +1086,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
+ case CK_HLSLSplatCast:
llvm_unreachable("HLSL-specific cast in Objective-C?");
break;
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
index 3a983421358c7f4..d75583f68eb6b7b 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
@@ -523,6 +523,7 @@ void ExprEngine::VisitCast(const CastExpr *CastE, const Expr *Ex,
case CK_MatrixCast:
case CK_VectorSplat:
case CK_HLSLElementwiseCast:
+ case CK_HLSLSplatCast:
case CK_HLSLVectorTruncation: {
QualType resultType = CastE->getType();
if (CastE->isGLValue())
>From 0650840642960d950d64e234e9641e34096a6c55 Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Wed, 11 Dec 2024 20:54:39 +0000
Subject: [PATCH 5/9] get code compiling after rebase
---
clang/lib/CodeGen/CGExprAgg.cpp | 13 ++++++++++---
1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 3584280e2fb9e44..3330cd03628f75e 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -496,10 +496,9 @@ static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
QualType SrcTy, SourceLocation Loc) {
// Flatten our destination
SmallVector<QualType> DestTypes; // Flattened type
- SmallVector<llvm::Value *, 4> IdxList;
SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
// ^^ Flattened accesses to DestVal we want to store into
- CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes);
+ CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
assert(VT->getNumElements() == 1 && "Invalid HLSL splat cast.");
@@ -511,7 +510,15 @@ static void EmitHLSLSplatCast(CodeGenFunction &CGF, Address DestVal,
for (unsigned i = 0; i < StoreGEPList.size(); i++) {
llvm::Value *Cast =
CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[i], Loc);
- CGF.PerformStore(StoreGEPList[i], Cast);
+
+ // store back
+ llvm::Value *Idx = StoreGEPList[i].second;
+ if (Idx) {
+ llvm::Value *V =
+ CGF.Builder.CreateLoad(StoreGEPList[i].first, "load.for.insert");
+ Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
+ }
+ CGF.Builder.CreateStore(Cast, StoreGEPList[i].first);
}
}
>From f924b13ada0c3344f3cc4f87a859f0ecd16705cb Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 12 Dec 2024 00:04:29 +0000
Subject: [PATCH 6/9] Self review
---
clang/lib/CodeGen/CGExprScalar.cpp | 15 +++++++-----
clang/lib/Sema/SemaHLSL.cpp | 7 +++---
clang/test/SemaHLSL/Language/SplatCasts.hlsl | 25 ++++++++++++++++++++
3 files changed, 38 insertions(+), 9 deletions(-)
create mode 100644 clang/test/SemaHLSL/Language/SplatCasts.hlsl
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 4a20b693b101fae..85c0265ea14b611 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2796,17 +2796,20 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
}
case CK_HLSLSplatCast: {
+ // This code should only handle splatting from vectors of length 1.
assert(DestTy->isVectorType() && "Destination type must be a vector.");
auto *DestVecTy = DestTy->getAs<VectorType>();
QualType SrcTy = E->getType();
SourceLocation Loc = CE->getExprLoc();
Value *V = Visit(const_cast<Expr *>(E));
- if (auto *VecTy = SrcTy->getAs<VectorType>()) {
- assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
- V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
- SrcTy = VecTy->getElementType();
- }
- assert(SrcTy->isScalarType() && "Invalid HLSL splat cast.");
+ assert(SrcTy->isVectorType() && "Invalid HLSL splat cast.");
+
+ auto *VecTy = SrcTy->getAs<VectorType>();
+ assert(VecTy->getNumElements() == 1 && "Invalid HLSL splat cast.");
+
+ V = CGF.Builder.CreateExtractElement(V, (uint64_t)0, "vec.load");
+ SrcTy = VecTy->getElementType();
+
Value *Cast =
EmitScalarConversion(V, SrcTy, DestVecTy->getElementType(), Loc);
return Builder.CreateVectorSplat(DestVecTy->getNumElements(), Cast,
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 024f778f8ffef5b..432a42016789ec2 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2814,12 +2814,13 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
return false;
const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
- if (!(SrcTy->isScalarType() || (SrcVecTy && SrcVecTy->getNumElements() == 1)))
- return false;
-
if (SrcVecTy)
SrcTy = SrcVecTy->getElementType();
+ // Src isn't a scalar or a vector of length 1
+ if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
+ return false;
+
llvm::SmallVector<QualType> DestTypes;
BuildFlattenedTypeList(DestTy, DestTypes);
diff --git a/clang/test/SemaHLSL/Language/SplatCasts.hlsl b/clang/test/SemaHLSL/Language/SplatCasts.hlsl
new file mode 100644
index 000000000000000..593a8e67fd4a3b8
--- /dev/null
+++ b/clang/test/SemaHLSL/Language/SplatCasts.hlsl
@@ -0,0 +1,25 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -finclude-default-header -fnative-half-type %s -ast-dump | FileCheck %s
+
+// splat from vec1 to vec
+// CHECK-LABEL: call1
+// CHECK: CStyleCastExpr {{.*}} 'int3':'vector<int, 3>' <HLSLSplatCast>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'float1':'vector<float, 1>' lvalue Var {{.*}} 'A' 'float1':'vector<float, 1>'
+export void call1() {
+ float1 A = {1.0};
+ int3 B = (int3)A;
+}
+
+struct S {
+ int A;
+ float B;
+ int C;
+ float D;
+};
+
+// splat from scalar to aggregate
+// CHECK-LABEL: call2
+// CHECK: CStyleCastExpr {{.*}} 'S' <HLSLSplatCast>
+// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 5
+export void call2() {
+ S s = (S)5;
+}
\ No newline at end of file
>From 89ceeb7d6b445f10fa6b7deb8c10267cd292da7b Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 12 Dec 2024 05:59:55 +0000
Subject: [PATCH 7/9] move code back that broke tests
---
clang/lib/Sema/SemaHLSL.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 432a42016789ec2..76ca24b10c60a16 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2814,13 +2814,14 @@ bool SemaHLSL::CanPerformSplat(Expr *Src, QualType DestTy) {
return false;
const VectorType *SrcVecTy = SrcTy->getAs<VectorType>();
- if (SrcVecTy)
- SrcTy = SrcVecTy->getElementType();
// Src isn't a scalar or a vector of length 1
if (!SrcTy->isScalarType() && !(SrcVecTy && SrcVecTy->getNumElements() == 1))
return false;
+ if (SrcVecTy)
+ SrcTy = SrcVecTy->getElementType();
+
llvm::SmallVector<QualType> DestTypes;
BuildFlattenedTypeList(DestTy, DestTypes);
>From 7f5b3e4f39f2a4cf2d42e5281e70d900878c1a3b Mon Sep 17 00:00:00 2001
From: Sarah Spall <spall at planetbauer.com>
Date: Thu, 12 Dec 2024 06:08:46 +0000
Subject: [PATCH 8/9] fix tests
---
.../CodeGenHLSL/BasicFeatures/SplatCast.hlsl | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
index 05359c1bce0ba35..2de68479179dd4c 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/SplatCast.hlsl
@@ -4,8 +4,8 @@
// CHECK-LABEL: define void {{.*}}call4
// CHECK: [[B:%.*]] = alloca [2 x i32], align 4
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
-// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
-// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
// CHECK-NEXT: store i32 3, ptr [[G1]], align 4
// CHECK-NEXT: store i32 3, ptr [[G2]], align 4
export void call4() {
@@ -20,8 +20,8 @@ export void call4() {
// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[B]], ptr align 4 {{.*}}, i32 8, i1 false)
// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
-// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0
-// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 1
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[B]], i32 0, i32 1
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
// CHECK-NEXT: store i32 [[VL]], ptr [[G2]], align 4
@@ -58,8 +58,8 @@ struct S {
// CHECK: [[s:%.*]] = alloca %struct.S, align 4
// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
-// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
-// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
@@ -75,8 +75,8 @@ export void call3() {
// CHECK-NEXT: [[s:%.*]] = alloca %struct.S, align 4
// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
-// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0
-// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 1
+// CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
+// CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i64 0
// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
>From 844ba82eb5dcfbd0105db2d4943266fa8d009c17 Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Sat, 8 Feb 2025 09:07:05 -0800
Subject: [PATCH 9/9] add cast to cases
---
clang/lib/CodeGen/CGExpr.cpp | 1 +
clang/lib/CodeGen/CGExprAgg.cpp | 2 +-
2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 2bbc0791c65876f..545d8b11a6a47a9 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -5339,6 +5339,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
case CK_HLSLVectorTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
+ case CK_HLSLSplatCast:
return EmitUnsupportedLValue(E, "unexpected cast lvalue");
case CK_Dependent:
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 3330cd03628f75e..b7fe62687b074a0 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -1599,7 +1599,7 @@ static bool castPreservesZero(const CastExpr *CE) {
case CK_AtomicToNonAtomic:
case CK_HLSLVectorTruncation:
case CK_HLSLElementwiseCast:
- // TODO is this true for CK_HLSLSplatCast
+ case CK_HLSLSplatCast:
return true;
case CK_BaseToDerivedMemberPointer:
More information about the cfe-commits
mailing list