[clang] [HLSL][Matrix] Add support for ICK_HLSL_Matrix_Splat to add splat cast of scalars (PR #170885)
Farzon Lotfi via cfe-commits
cfe-commits at lists.llvm.org
Fri Dec 5 09:12:26 PST 2025
https://github.com/farzonl created https://github.com/llvm/llvm-project/pull/170885
fixes #168960
Adds `ICK_HLSL_Matrix_Splat` and hooks it up to `PerformImplicitConversion` and `IsMatrixConversion`. Map these to `CK_HLSLAggregateSplatCast`.
From e4dd6308c4141fe86f0b8f9b2f4a39108c51c34c Mon Sep 17 00:00:00 2001
From: Farzon Lotfi <farzonlotfi at microsoft.com>
Date: Fri, 5 Dec 2025 11:56:00 -0500
Subject: [PATCH] [HLSL][Matrix] Add support for ICK_HLSL_Matrix_Splat to add
splat cast of scalars
fixes #168960
Adds `ICK_HLSL_Matrix_Splat` and hooks it up to `PerformImplicitConversion` and `IsMatrixConversion`. Map these to `CK_HLSLAggregateSplatCast`.
---
clang/include/clang/Sema/Overload.h | 3 +
clang/include/clang/Sema/Sema.h | 4 ++
clang/lib/Sema/SemaExpr.cpp | 33 +++++++++++
clang/lib/Sema/SemaExprCXX.cpp | 10 ++++
clang/lib/Sema/SemaOverload.cpp | 12 ++++
.../BasicFeatures/MatrixSplat.hlsl | 57 +++++++++++++++++++
.../MatrixElementOverloadResolution.hlsl | 12 +++-
.../BuiltinMatrix/MatrixSplatErrors.hlsl | 11 ++++
8 files changed, 139 insertions(+), 3 deletions(-)
create mode 100644 clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl
create mode 100644 clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixSplatErrors.hlsl
diff --git a/clang/include/clang/Sema/Overload.h b/clang/include/clang/Sema/Overload.h
index ab45328ee8ab7..cc9be00e9108c 100644
--- a/clang/include/clang/Sema/Overload.h
+++ b/clang/include/clang/Sema/Overload.h
@@ -207,6 +207,9 @@ class Sema;
// HLSL vector splat from scalar or boolean type.
ICK_HLSL_Vector_Splat,
+ /// HLSL matrix splat from scalar or boolean type.
+ ICK_HLSL_Matrix_Splat,
+
/// The number of conversion kinds
ICK_Num_Conversion_Kinds,
};
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 4a601a0eaf1b9..2a32bf8b257ad 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -7944,6 +7944,10 @@ class Sema final : public SemaBase {
/// implicit casts if necessary.
ExprResult prepareVectorSplat(QualType VectorTy, Expr *SplattedExpr);
+ /// Prepare `SplattedExpr` for a matrix splat operation, adding
+ /// implicit casts if necessary.
+ ExprResult prepareMatrixSplat(QualType MatrixTy, Expr *SplattedExpr);
+
// CheckExtVectorCast - check type constraints for extended vectors.
// Since vectors are an extension, there are no C standard reference for this.
// We allow casting between vectors and integer datatypes of the same size,
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index cfabd1b76c103..f5b6855b87c33 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -7806,6 +7806,50 @@ ExprResult Sema::prepareVectorSplat(QualType VectorTy, Expr *SplattedExpr) {
return ImpCastExprToType(SplattedExpr, DestElemTy, CK);
}
+/// Converts \p SplattedExpr to the element type of \p MatrixTy so it can be
+/// used as the operand of a matrix splat cast.
+///
+/// \returns \p SplattedExpr unchanged if it already has the element type,
+/// the implicitly cast expression otherwise, or ExprError() if a cast could
+/// not be built.
+ExprResult Sema::prepareMatrixSplat(QualType MatrixTy, Expr *SplattedExpr) {
+  QualType DestElemTy = MatrixTy->castAs<MatrixType>()->getElementType();
+
+  // Already the element type: no cast needed.
+  if (DestElemTy == SplattedExpr->getType())
+    return SplattedExpr;
+
+  assert((DestElemTy->isFloatingType() ||
+          DestElemTy->isIntegralOrEnumerationType()) &&
+         "matrix splat needs a floating or integral element type");
+
+  CastKind CK;
+  if (SplattedExpr->getType()->isBooleanType()) {
+    // As with vectors, we want `true` to become -1 when splatting, and we
+    // need a two-step cast if the destination element type is floating.
+    if (DestElemTy->isFloatingType()) {
+      // Cast boolean to signed integral, then to floating.
+      ExprResult CastExprRes = ImpCastExprToType(SplattedExpr, Context.IntTy,
+                                                 CK_BooleanToSignedIntegral);
+      // Propagate a failed cast instead of using its result.
+      if (CastExprRes.isInvalid())
+        return ExprError();
+      SplattedExpr = CastExprRes.get();
+      CK = CK_IntegralToFloating;
+    } else {
+      CK = CK_BooleanToSignedIntegral;
+    }
+  } else {
+    ExprResult CastExprRes = SplattedExpr;
+    CK = PrepareScalarCast(CastExprRes, DestElemTy);
+    if (CastExprRes.isInvalid())
+      return ExprError();
+    SplattedExpr = CastExprRes.get();
+  }
+
+  return ImpCastExprToType(SplattedExpr, DestElemTy, CK);
+}
+
ExprResult Sema::CheckExtVectorCast(SourceRange R, QualType DestTy,
Expr *CastExpr, CastKind &Kind) {
assert(DestTy->isExtVectorType() && "Not an extended vector type!");
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index 69719ebd1fc8c..e7af3579be69a 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -5198,6 +5198,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
case ICK_HLSL_Vector_Truncation:
case ICK_HLSL_Matrix_Truncation:
case ICK_HLSL_Vector_Splat:
+ case ICK_HLSL_Matrix_Splat:
llvm_unreachable("Improper second standard conversion");
}
@@ -5217,6 +5218,15 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
.get();
break;
}
+ case ICK_HLSL_Matrix_Splat: {
+ // Matrix splat from any arithmetic type to a matrix.
+ Expr *Elem = prepareMatrixSplat(ToType, From).get();
+ From =
+ ImpCastExprToType(Elem, ToType, CK_HLSLAggregateSplatCast, VK_PRValue,
+ /*BasePath=*/nullptr, CCK)
+ .get();
+ break;
+ }
case ICK_HLSL_Vector_Truncation: {
// Note: HLSL built-in vectors are ExtVectors. Since this truncates a
// vector to a smaller vector or to a scalar, this can only operate on
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 9a3a78164f0f8..bc3cfe7ef9a0c 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -165,6 +165,7 @@ ImplicitConversionRank clang::GetConversionRank(ImplicitConversionKind Kind) {
ICR_HLSL_Dimension_Reduction,
ICR_Conversion,
ICR_HLSL_Scalar_Widening,
+ ICR_HLSL_Scalar_Widening,
};
static_assert(std::size(Rank) == (int)ICK_Num_Conversion_Kinds);
return Rank[(int)Kind];
@@ -228,6 +229,7 @@ static const char *GetImplicitConversionName(ImplicitConversionKind Kind) {
"HLSL matrix truncation",
"Non-decaying array conversion",
"HLSL vector splat",
+ "HLSL matrix splat",
};
static_assert(std::size(Name) == (int)ICK_Num_Conversion_Kinds);
return Name[Kind];
@@ -2145,6 +2147,15 @@ static bool IsMatrixConversion(Sema &S, QualType FromType, QualType ToType,
return true;
return IsVectorOrMatrixElementConversion(S, FromElTy, ToElTy, ICK, From);
}
+
+  // Matrix splat from any arithmetic type to a matrix. The splat itself is
+  // reported through ElConv; the scalar-to-element conversion is classified
+  // by IsVectorOrMatrixElementConversion, which receives ICK.
+  if (ToMatrixType && FromType->isArithmeticType()) {
+    ElConv = ICK_HLSL_Matrix_Splat;
+    QualType ToElTy = ToMatrixType->getElementType();
+    return IsVectorOrMatrixElementConversion(S, FromType, ToElTy, ICK, From);
+  }
if (FromMatrixType && !ToMatrixType) {
ElConv = ICK_HLSL_Matrix_Truncation;
QualType FromElTy = FromMatrixType->getElementType();
@@ -6301,6 +6312,7 @@ static bool CheckConvertedConstantConversions(Sema &S,
case ICK_SVE_Vector_Conversion:
case ICK_RVV_Vector_Conversion:
case ICK_HLSL_Vector_Splat:
+ case ICK_HLSL_Matrix_Splat:
case ICK_Vector_Splat:
case ICK_Complex_Real:
case ICK_Block_Pointer_Conversion:
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl
new file mode 100644
index 0000000000000..802c418f1dad5
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSplat.hlsl
@@ -0,0 +1,57 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 6
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.7-library -disable-llvm-passes -emit-llvm -finclude-default-header -o - %s | FileCheck %s
+
+// CHECK-LABEL: define hidden void @_Z13ConstantSplatv(
+// CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[M:%.*]] = alloca [16 x i32], align 4
+// CHECK-NEXT: store <16 x i32> splat (i32 1), ptr [[M]], align 4
+// CHECK-NEXT: ret void
+//
+void ConstantSplat() {
+ int4x4 M = 1;
+}
+
+// CHECK-LABEL: define hidden void @_Z18ConstantFloatSplatv(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[M:%.*]] = alloca [4 x float], align 4
+// CHECK-NEXT: store <4 x float> splat (float 3.250000e+00), ptr [[M]], align 4
+// CHECK-NEXT: ret void
+//
+void ConstantFloatSplat() {
+ float2x2 M = 3.25;
+}
+
+// CHECK-LABEL: define hidden void @_Z12DynamicSplatf(
+// CHECK-SAME: float noundef nofpclass(nan inf) [[VALUE:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[VALUE_ADDR:%.*]] = alloca float, align 4
+// CHECK-NEXT: [[M:%.*]] = alloca [9 x float], align 4
+// CHECK-NEXT: store float [[VALUE]], ptr [[VALUE_ADDR]], align 4
+// CHECK-NEXT: [[TMP0:%.*]] = load float, ptr [[VALUE_ADDR]], align 4
+// CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <9 x float> poison, float [[TMP0]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <9 x float> [[SPLAT_SPLATINSERT]], <9 x float> poison, <9 x i32> zeroinitializer
+// CHECK-NEXT: store <9 x float> [[SPLAT_SPLAT]], ptr [[M]], align 4
+// CHECK-NEXT: ret void
+//
+void DynamicSplat(float Value) {
+ float3x3 M = Value;
+}
+
+// CHECK-LABEL: define hidden void @_Z13CastThenSplatDv4_f(
+// CHECK-SAME: <4 x float> noundef nofpclass(nan inf) [[VALUE:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[VALUE_ADDR:%.*]] = alloca <4 x float>, align 16
+// CHECK-NEXT: [[M:%.*]] = alloca [9 x float], align 4
+// CHECK-NEXT: store <4 x float> [[VALUE]], ptr [[VALUE_ADDR]], align 16
+// CHECK-NEXT: [[TMP0:%.*]] = load <4 x float>, ptr [[VALUE_ADDR]], align 16
+// CHECK-NEXT: [[CAST_VTRUNC:%.*]] = extractelement <4 x float> [[TMP0]], i32 0
+// CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <9 x float> poison, float [[CAST_VTRUNC]], i64 0
+// CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <9 x float> [[SPLAT_SPLATINSERT]], <9 x float> poison, <9 x i32> zeroinitializer
+// CHECK-NEXT: store <9 x float> [[SPLAT_SPLAT]], ptr [[M]], align 4
+// CHECK-NEXT: ret void
+//
+void CastThenSplat(float4 Value) {
+ float3x3 M = (float) Value;
+}
diff --git a/clang/test/SemaHLSL/MatrixElementOverloadResolution.hlsl b/clang/test/SemaHLSL/MatrixElementOverloadResolution.hlsl
index 04149e176edbd..51500a3bcc145 100644
--- a/clang/test/SemaHLSL/MatrixElementOverloadResolution.hlsl
+++ b/clang/test/SemaHLSL/MatrixElementOverloadResolution.hlsl
@@ -228,12 +228,14 @@ void fn2x2(float2x2) {}
void fn2x2IO(inout float2x2) {}
void fnI2x2IO(inout int2x2) {}
-void matOrVec(float4 F) {}
-void matOrVec(float2x2 F) {}
+void matOrVec(float4 F) {} // expected-note {{candidate function}}
+void matOrVec(float2x2 F) {} // expected-note {{candidate function}}
void matOrVec2(float3 F) {} // expected-note{{candidate function}}
void matOrVec2(float2x3 F) {} // expected-note{{candidate function}}
+void matOrVec3(float4x4 F) {}
+
export void Case8(float2x3 f23, float4x4 f44, float3x3 f33, float3x2 f32) {
int2x2 i22 = f23;
// expected-warning@-1{{implicit conversion truncates matrix: 'float2x3' (aka 'matrix<float, 2, 3>') to 'int2x2' (aka 'matrix<int, 2, 2>')}}
@@ -269,8 +271,12 @@ export void Case8(float2x3 f23, float4x4 f44, float3x3 f33, float3x2 f32) {
//CHECK-NEXT: ImplicitCastExpr {{.*}} 'float4x4':'matrix<float, 4, 4>' <LValueToRValue>
#ifdef ERROR
- matOrVec(2.0); // TODO: See #168960 this should be ambiguous once we implement ICK_HLSL_Matrix_Splat.
+ matOrVec(2.0); // expected-error {{call to 'matOrVec' is ambiguous}}
#endif
+ matOrVec3(3.14);
+ //CHECK: ImplicitCastExpr {{.*}} 'float4x4':'matrix<float, 4, 4>' <HLSLAggregateSplatCast>
+ //CHECK-NEXT: FloatingLiteral {{.*}} <col:13> 'float' 3.140000e+00
+
matOrVec2(f23);
//CHECK: DeclRefExpr {{.*}} 'void (float2x3)' lvalue Function {{.*}} 'matOrVec2' 'void (float2x3)'
//CHECK-NEXT: ImplicitCastExpr {{.*}} 'float2x3':'matrix<float, 2, 3>' <LValueToRValue>
diff --git a/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixSplatErrors.hlsl b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixSplatErrors.hlsl
new file mode 100644
index 0000000000000..0c2e53d382180
--- /dev/null
+++ b/clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixSplatErrors.hlsl
@@ -0,0 +1,11 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -finclude-default-header -std=hlsl202x -verify %s
+
+void SplatOfVectortoMat(int4 V){
+ int2x2 M = V;
+ // expected-error@-1 {{cannot initialize a variable of type 'int2x2' (aka 'matrix<int, 2, 2>') with an lvalue of type 'int4' (aka 'vector<int, 4>')}}
+}
+
+void SplatOfMattoMat(int4x3 N){
+ int4x4 M = N;
+ // expected-error@-1 {{cannot initialize a variable of type 'matrix<[2 * ...], 4>' with an lvalue of type 'matrix<[2 * ...], 3>'}}
+}
More information about the cfe-commits
mailing list