[clang] [HLSL][Matrix] Add support for Matrix element and trunc Casts (PR #168915)
via cfe-commits
cfe-commits at lists.llvm.org
Thu Nov 20 09:32:46 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang-codegen
Author: Farzon Lotfi (farzonl)
<details>
<summary>Changes</summary>
fixes #<!-- -->168737
fixes #<!-- -->168755
This change adds support for Matrix truncations via the ICK_HLSL_Matrix_Truncation enum. Handling that new enum accounts for most of the files changed.
It also allows a Matrix in an HLSL Elementwise cast as long as the cast does not perform a shape transformation, e.g. 3x2 to 2x3.
Tests for the new elementwise and truncation behavior were added, as well as Sema tests to make sure we error on the shape transformation cast.
I am punting right now on the ConstExpr Matrix support. That will need to be addressed later. I will file a separate issue for that if reviewers agree it can wait.
---
Patch is 38.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168915.diff
19 Files Affected:
- (modified) clang/include/clang/AST/OperationKinds.def (+3)
- (modified) clang/include/clang/Sema/Overload.h (+3)
- (modified) clang/lib/AST/Expr.cpp (+1)
- (modified) clang/lib/AST/ExprConstant.cpp (+13)
- (modified) clang/lib/CIR/CodeGen/CIRGenExpr.cpp (+2)
- (modified) clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp (+1)
- (modified) clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp (+1)
- (modified) clang/lib/CodeGen/CGExpr.cpp (+1)
- (modified) clang/lib/CodeGen/CGExprAgg.cpp (+2-1)
- (modified) clang/lib/CodeGen/CGExprComplex.cpp (+1)
- (modified) clang/lib/CodeGen/CGExprConstant.cpp (+1)
- (modified) clang/lib/CodeGen/CGExprScalar.cpp (+34-1)
- (modified) clang/lib/Sema/SemaExprCXX.cpp (+16-6)
- (modified) clang/lib/Sema/SemaHLSL.cpp (+4-1)
- (modified) clang/lib/Sema/SemaOverload.cpp (+69-6)
- (modified) clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp (+1)
- (added) clang/test/CodeGenHLSL/BasicFeatures/MatrixElementTypeCast.hlsl (+186)
- (added) clang/test/CodeGenHLSL/BasicFeatures/MatrixTruncation.hlsl (+156)
- (added) clang/test/SemaHLSL/Types/BuiltinMatrix/MatrixCastErrors.hlsl (+21)
``````````diff
diff --git a/clang/include/clang/AST/OperationKinds.def b/clang/include/clang/AST/OperationKinds.def
index c2dca895e8411..8a13ad988403b 100644
--- a/clang/include/clang/AST/OperationKinds.def
+++ b/clang/include/clang/AST/OperationKinds.def
@@ -364,6 +364,9 @@ CAST_OPERATION(IntToOCLSampler)
// Truncate a vector type by dropping elements from the end (HLSL only).
CAST_OPERATION(HLSLVectorTruncation)
+// Truncate a matrix type by dropping elements from the end (HLSL only).
+CAST_OPERATION(HLSLMatrixTruncation)
+
// Non-decaying array RValue cast (HLSL only).
CAST_OPERATION(HLSLArrayRValue)
diff --git a/clang/include/clang/Sema/Overload.h b/clang/include/clang/Sema/Overload.h
index 59bbd0fbd9e95..1ad52cb9da517 100644
--- a/clang/include/clang/Sema/Overload.h
+++ b/clang/include/clang/Sema/Overload.h
@@ -198,6 +198,9 @@ class Sema;
/// HLSL vector truncation.
ICK_HLSL_Vector_Truncation,
+ /// HLSL Matrid truncation.
+ ICK_HLSL_Matrix_Truncation,
+
/// HLSL non-decaying array rvalue cast.
ICK_HLSL_Array_RValue,
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index 1d914fa876759..159ea4867857d 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -1937,6 +1937,7 @@ bool CastExpr::CastConsistency() const {
case CK_FixedPointToBoolean:
case CK_HLSLArrayRValue:
case CK_HLSLVectorTruncation:
+ case CK_HLSLMatrixTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
CheckNoBasePath:
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 74f6e3acb6b39..b7ea213679d2a 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -11773,6 +11773,10 @@ bool VectorExprEvaluator::VisitCastExpr(const CastExpr *E) {
Elements.push_back(Val.getVectorElt(I));
return Success(Elements, E);
}
+ case CK_HLSLMatrixTruncation: {
+ // TODO: support Expr Constant for Matrix Truncation
+ return Error(E);
+ }
case CK_HLSLAggregateSplatCast: {
APValue Val;
QualType ValTy;
@@ -18011,6 +18015,10 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
return Error(E);
return Success(Val.getVectorElt(0), E);
}
+ case CK_HLSLMatrixTruncation: {
+ // TODO: support Expr Constant for Matrix Truncation
+ return Error(E);
+ }
case CK_HLSLElementwiseCast: {
SmallVector<APValue> SrcVals;
SmallVector<QualType> SrcTypes;
@@ -18604,6 +18612,10 @@ bool FloatExprEvaluator::VisitCastExpr(const CastExpr *E) {
return Error(E);
return Success(Val.getVectorElt(0), E);
}
+ case CK_HLSLMatrixTruncation: {
+ // TODO: support Expr Constant for Matrix Truncation
+ return Error(E);
+ }
case CK_HLSLElementwiseCast: {
SmallVector<APValue> SrcVals;
SmallVector<QualType> SrcTypes;
@@ -18761,6 +18773,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
+ case CK_HLSLMatrixTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
llvm_unreachable("invalid cast kind for complex value");
diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
index 8607558c1cf7d..abfbca16cd60b 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
@@ -188,6 +188,7 @@ Address CIRGenFunction::emitPointerWithAlignment(const Expr *expr,
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLVectorTruncation:
+ case CK_HLSLMatrixTruncation:
case CK_IntToOCLSampler:
case CK_IntegralCast:
case CK_IntegralComplexCast:
@@ -1279,6 +1280,7 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) {
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
+ case CK_HLSLMatrixTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp b/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
index 9ed920085c8c6..fe06f8cc2c430 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprComplex.cpp
@@ -534,6 +534,7 @@ mlir::Value ComplexExprEmitter::emitCast(CastKind ck, Expr *op,
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
+ case CK_HLSLMatrixTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
index 6af87a0159f0a..7ce02f9b42af4 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
@@ -1012,6 +1012,7 @@ class ConstExprEmitter
case CK_MatrixCast:
case CK_HLSLArrayRValue:
case CK_HLSLVectorTruncation:
+ case CK_HLSLMatrixTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
return {};
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index f2451b16e78be..1737301c67021 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -5744,6 +5744,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
+ case CK_HLSLMatrixTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 67b5f919d1b2a..7cc4d6c8f06f6 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -1036,7 +1036,7 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
case CK_ZeroToOCLOpaqueType:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
-
+ case CK_HLSLMatrixTruncation:
case CK_IntToOCLSampler:
case CK_FloatingToFixedPoint:
case CK_FixedPointToFloating:
@@ -1550,6 +1550,7 @@ static bool castPreservesZero(const CastExpr *CE) {
case CK_NonAtomicToAtomic:
case CK_AtomicToNonAtomic:
case CK_HLSLVectorTruncation:
+ case CK_HLSLMatrixTruncation:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
return true;
diff --git a/clang/lib/CodeGen/CGExprComplex.cpp b/clang/lib/CodeGen/CGExprComplex.cpp
index f8a946a76554a..e6683d4c931b8 100644
--- a/clang/lib/CodeGen/CGExprComplex.cpp
+++ b/clang/lib/CodeGen/CGExprComplex.cpp
@@ -621,6 +621,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
case CK_IntegralToFixedPoint:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
+ case CK_HLSLMatrixTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index 6407afc3d9447..0eec4dba4824a 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -1333,6 +1333,7 @@ class ConstExprEmitter
case CK_ZeroToOCLOpaqueType:
case CK_MatrixCast:
case CK_HLSLVectorTruncation:
+ case CK_HLSLMatrixTruncation:
case CK_HLSLArrayRValue:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 714192db1b15c..a9e2ebdffa59a 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2422,9 +2422,27 @@ static Value *EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue SrcVal,
}
return V;
}
+ if (auto *MatTy = DestTy->getAs<ConstantMatrixType>()) {
+ assert(LoadList.size() >= MatTy->getNumElementsFlattened() &&
+ "Flattened type on RHS must have the same number or more elements "
+ "than vector on LHS.");
+ llvm::Value *V =
+ CGF.Builder.CreateLoad(CGF.CreateIRTemp(DestTy, "flatcast.tmp"));
+ // write to V.
+ for (unsigned I = 0, E = MatTy->getNumElementsFlattened(); I < E; I++) {
+ RValue RVal = CGF.EmitLoadOfLValue(LoadList[I], Loc);
+ assert(RVal.isScalar() &&
+ "All flattened source values should be scalars.");
+ llvm::Value *Cast =
+ CGF.EmitScalarConversion(RVal.getScalarVal(), LoadList[I].getType(),
+ MatTy->getElementType(), Loc);
+ V = CGF.Builder.CreateInsertElement(V, Cast, I);
+ }
+ return V;
+ }
// if its a builtin just do an extract element or load.
assert(DestTy->isBuiltinType() &&
- "Destination type must be a vector or builtin type.");
+ "Destination type must be a vector, matrix, or builtin type.");
RValue RVal = CGF.EmitLoadOfLValue(LoadList[0], Loc);
assert(RVal.isScalar() && "All flattened source values should be scalars.");
return CGF.EmitScalarConversion(RVal.getScalarVal(), LoadList[0].getType(),
@@ -2954,6 +2972,21 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
}
+ case CK_HLSLMatrixTruncation: {
+ assert((DestTy->isMatrixType() || DestTy->isBuiltinType()) &&
+ "Destination type must be a matrix or builtin type.");
+ Value *Mat = Visit(E);
+ if (auto *MatTy = DestTy->getAs<ConstantMatrixType>()) {
+ SmallVector<int> Mask;
+ unsigned NumElts = MatTy->getNumElementsFlattened();
+ for (unsigned I = 0; I != NumElts; ++I)
+ Mask.push_back(I);
+
+ return Builder.CreateShuffleVector(Mat, Mask, "trunc");
+ }
+ llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
+ return Builder.CreateExtractElement(Mat, Zero, "cast.mtrunc");
+ }
case CK_HLSLElementwiseCast: {
RValue RV = CGF.EmitAnyExpr(E);
SourceLocation Loc = CE->getExprLoc();
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index dc7ed4e9a48bc..be3ac296f2597 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -5197,6 +5197,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
case ICK_Incompatible_Pointer_Conversion:
case ICK_HLSL_Array_RValue:
case ICK_HLSL_Vector_Truncation:
+ case ICK_HLSL_Matrix_Truncation:
case ICK_HLSL_Vector_Splat:
llvm_unreachable("Improper second standard conversion");
}
@@ -5204,12 +5205,10 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
if (SCS.Dimension != ICK_Identity) {
// If SCS.Element is not ICK_Identity the To and From types must be HLSL
// vectors or matrices.
-
- // TODO: Support HLSL matrices.
- assert((!From->getType()->isMatrixType() && !ToType->isMatrixType()) &&
- "Dimension conversion for matrix types is not implemented yet.");
- assert((ToType->isVectorType() || ToType->isBuiltinType()) &&
- "Dimension conversion output must be vector or scalar type.");
+ assert(
+ (ToType->isVectorType() || ToType->isConstantMatrixType() ||
+ ToType->isBuiltinType()) &&
+ "Dimension conversion output must be vector, matrix, or scalar type.");
switch (SCS.Dimension) {
case ICK_HLSL_Vector_Splat: {
// Vector splat from any arithmetic type to a vector.
@@ -5235,6 +5234,17 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
break;
}
+ case ICK_HLSL_Matrix_Truncation: {
+ auto *FromMat = From->getType()->castAs<ConstantMatrixType>();
+ QualType TruncTy = FromMat->getElementType();
+ if (auto *ToMat = ToType->getAs<ConstantMatrixType>())
+ TruncTy = Context.getConstantMatrixType(TruncTy, ToMat->getNumRows(),
+ ToMat->getNumColumns());
+ From = ImpCastExprToType(From, TruncTy, CK_HLSLMatrixTruncation,
+ From->getValueKind())
+ .get();
+ break;
+ }
case ICK_Identity:
default:
llvm_unreachable("Improper element standard conversion");
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 5555916c2536f..168bfc3da99e0 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3728,7 +3728,6 @@ bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
}
// Can we perform an HLSL Elementwise cast?
-// TODO: update this code when matrices are added; see issue #88060
bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
// Don't handle casts where LHS and RHS are any combination of scalar/vector
@@ -3741,6 +3740,10 @@ bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
(DestTy->isScalarType() || DestTy->isVectorType()))
return false;
+ if (SrcTy->isConstantMatrixType() &&
+ (DestTy->isScalarType() || DestTy->isConstantMatrixType()))
+ return false;
+
llvm::SmallVector<QualType> DestTypes;
BuildFlattenedTypeList(DestTy, DestTypes);
llvm::SmallVector<QualType> SrcTypes;
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 37f351174e3d0..f912c2431fc6f 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -162,6 +162,7 @@ ImplicitConversionRank clang::GetConversionRank(ImplicitConversionKind Kind) {
ICR_C_Conversion_Extension,
ICR_Conversion,
ICR_HLSL_Dimension_Reduction,
+ ICR_HLSL_Dimension_Reduction,
ICR_Conversion,
ICR_HLSL_Scalar_Widening,
};
@@ -224,6 +225,7 @@ static const char *GetImplicitConversionName(ImplicitConversionKind Kind) {
"Incompatible pointer conversion",
"Fixed point conversion",
"HLSL vector truncation",
+ "HLSL matrix truncation",
"Non-decaying array conversion",
"HLSL vector splat",
};
@@ -2046,9 +2048,10 @@ static bool IsFloatingPointConversion(Sema &S, QualType FromType,
return true;
}
-static bool IsVectorElementConversion(Sema &S, QualType FromType,
- QualType ToType,
- ImplicitConversionKind &ICK, Expr *From) {
+static bool IsVectorOrMatrixElementConversion(Sema &S, QualType FromType,
+ QualType ToType,
+ ImplicitConversionKind &ICK,
+ Expr *From) {
if (S.Context.hasSameUnqualifiedType(FromType, ToType))
return true;
@@ -2088,6 +2091,59 @@ static bool IsVectorElementConversion(Sema &S, QualType FromType,
return false;
}
+/// Determine whether the conversion from FromType to ToType is a valid
+/// matrix conversion.
+///
+/// \param ICK Will be set to the matrix conversion kind, if this is a matrix
+/// conversion.
+static bool IsMatrixConversion(Sema &S, QualType FromType, QualType ToType,
+ ImplicitConversionKind &ICK,
+ ImplicitConversionKind &ElConv, Expr *From,
+ bool InOverloadResolution, bool CStyle) {
+ // The non HLSL Matrix conversion rules are not clear.
+ if (!S.getLangOpts().HLSL)
+ return false;
+
+ auto *ToMatrixType = ToType->getAs<ConstantMatrixType>();
+ auto *FromMatrixType = FromType->getAs<ConstantMatrixType>();
+
+ // If both arguments are vectors, handle possible vector truncation and
+ // element conversion.
+ if (ToMatrixType && FromMatrixType) {
+ unsigned FromCols = FromMatrixType->getNumColumns();
+ unsigned ToCols = ToMatrixType->getNumColumns();
+ if (FromCols < ToCols)
+ return false;
+
+ unsigned FromRows = FromMatrixType->getNumRows();
+ unsigned ToRows = ToMatrixType->getNumRows();
+ if (FromRows < ToRows)
+ return false;
+
+ unsigned FromElts = FromMatrixType->getNumElementsFlattened();
+ unsigned ToElts = ToMatrixType->getNumElementsFlattened();
+ if (FromElts == ToElts)
+ ElConv = ICK_Identity;
+ else
+ ElConv = ICK_HLSL_Matrix_Truncation;
+
+ QualType FromElTy = FromMatrixType->getElementType();
+ QualType ToElTy = ToMatrixType->getElementType();
+ if (S.Context.hasSameUnqualifiedType(FromElTy, ToElTy))
+ return true;
+ return IsVectorOrMatrixElementConversion(S, FromElTy, ToElTy, ICK, From);
+ }
+ if (FromMatrixType && !ToMatrixType) {
+ ElConv = ICK_HLSL_Matrix_Truncation;
+ QualType FromElTy = FromMatrixType->getElementType();
+ if (S.Context.hasSameUnqualifiedType(FromElTy, ToType))
+ return true;
+ return IsVectorOrMatrixElementConversion(S, FromElTy, ToType, ICK, From);
+ }
+
+ return false;
+}
+
/// Determine whether the conversion from FromType to ToType is a valid
/// vector conversion.
///
@@ -2127,14 +2183,14 @@ static bool IsVectorConversion(Sema &S, QualType FromType, QualType ToType,
QualType ToElTy = ToExtType->getElementType();
if (S.Context.hasSameUnqualifiedType(FromElTy, ToElTy))
return true;
- return IsVectorElementConversion(S, FromElTy, ToElTy, ICK, From);
+ return IsVectorOrMatrixElementConversion(S, FromElTy, ToElTy, ICK, From);
}
if (FromExtType && !ToExtType) {
ElConv = ICK_HLSL_Vector_Truncation;
QualType FromElTy = FromExtType->getElementType();
if (S.Context.hasSameUnqualifiedType(FromElTy, ToType))
return true;
- return IsVectorElementConversion(S, FromElTy, ToType, ICK, From);
+ return IsVectorOrMatrixElementConversion(S, FromElTy, ToType, ICK, From);
}
// Fallthrough for the case where ToType is a vector and FromType is not.
}
@@ -2161,7 +2217,8 @@ static bool IsVectorConversion(Sema &S, QualType FromType, QualType ToType,
if (S.getLangOpts().HLSL) {
ElConv = ICK_HLSL_Vector_Splat;
QualType ToElTy = ToExtType->getElementType();
- return IsVectorElementConversion(S, FromType, ToElTy, ICK, From);
+ return IsVectorOrMatrixElementConversion(S, FromType, ToElTy, ICK,
+ From);
}
ICK = ICK_Vector_Splat;
return true;
@@ -2460,6 +2517,11 @@ static bool IsStandardConversion(Sema &S, Expr* From, QualType ToType,
SCS.Second = SecondICK;
SCS.Dimension = DimensionICK;
FromType = ToType.getUnqualifiedType();
+ } else if (IsMatrixConversion(S, FromType, ToType, SecondICK, DimensionICK,
+ From, InOverloadResolution, CStyle)) {
+ SCS.Second = SecondICK;
+ SCS.Dimension = DimensionICK;
+ FromType = ToType.getUnqualifiedType();
} else if (!S.getLangOpts().CPlusPlus &&
S.Context.typesAreCompatible(ToType, FromType)) {
// Compatible conversions (Clang extension for C function overloading)
@@ -6237,6 +6299,7 @@ static bool CheckConvertedConstantConversions(Sema &S,
case ICK_Incompatible_Pointer_Conversion:
case ICK_Fixed_Point_Conversion:
case ICK_HLSL_Vector_Truncation:
+ case ICK_HLSL_Matrix_Truncation:
return false;
case ICK_Lvalue_To_Rvalue:
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
index 4ddf8fd5b4b0f..db27c06cd18a3 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp
@@ -560,6 +560,7 @@ void ExprEngine::VisitCast(const CastExpr *CastE, const Expr *Ex,
case CK_VectorSplat:
case CK_HLSLElementwiseCast:
case CK_HLSLAggregateSplatCast:
+ case CK_HLSLMatrixTruncation:
case CK_HLSLVectorTruncation: {
QualType resultType = CastE->getType();
if (CastE->isGLValue())
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/MatrixElementTypeCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/MatrixElementTypeCast.hlsl
new file mode 100644
index 0000000000000..081b8013efcbc
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BasicFeatures/MatrixElementTypeCast.hlsl
@@ -0,0 +1,186 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 6
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -fnative-half-type -fnative-int16-type -o - %s | FileCheck %s
+
+
+// CHECK-LABEL: define hidden noundef <6 x i32> @_Z22elementwise_type_cast0u11matrix_typeILm3ELm2EfE(
+// CHECK-SAME: <6 x float> noundef nofpclass(nan inf) [[F32:%.*]]) #[[ATTR0:[0-9]+]...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/168915
More information about the cfe-commits
mailing list