[clang] Implement HLSL Flat casting (excluding splat cases) (PR #118842)

via cfe-commits cfe-commits at lists.llvm.org
Thu Dec 5 09:49:40 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-hlsl

Author: Sarah Spall (spall)

<details>
<summary>Changes</summary>

Implement HLSL Flat casting, excluding splat cases.
Partly closes #100609 and #100619.

---

Patch is 40.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118842.diff


19 Files Affected:

- (modified) clang/include/clang/AST/OperationKinds.def (+3) 
- (modified) clang/include/clang/Sema/SemaHLSL.h (+2) 
- (modified) clang/lib/AST/Expr.cpp (+1) 
- (modified) clang/lib/AST/ExprConstant.cpp (+1) 
- (modified) clang/lib/CodeGen/CGExpr.cpp (+86) 
- (modified) clang/lib/CodeGen/CGExprAgg.cpp (+78-1) 
- (modified) clang/lib/CodeGen/CGExprComplex.cpp (+1) 
- (modified) clang/lib/CodeGen/CGExprConstant.cpp (+1) 
- (modified) clang/lib/CodeGen/CGExprScalar.cpp (+40) 
- (modified) clang/lib/CodeGen/CodeGenFunction.h (+8) 
- (modified) clang/lib/Edit/RewriteObjCFoundationAPI.cpp (+1) 
- (modified) clang/lib/Sema/Sema.cpp (+1) 
- (modified) clang/lib/Sema/SemaCast.cpp (+19-3) 
- (modified) clang/lib/Sema/SemaHLSL.cpp (+144) 
- (modified) clang/lib/StaticAnalyzer/Core/ExprEngineC.cpp (+1) 
- (added) clang/test/CodeGenHLSL/BasicFeatures/ArrayFlatCast.hlsl (+128) 
- (added) clang/test/CodeGenHLSL/BasicFeatures/StructFlatCast.hlsl (+124) 
- (added) clang/test/CodeGenHLSL/BasicFeatures/VectorFlatCast.hlsl (+81) 
- (modified) clang/test/SemaHLSL/BuiltIns/vector-constructors-erros.hlsl (-2) 


``````````diff
diff --git a/clang/include/clang/AST/OperationKinds.def b/clang/include/clang/AST/OperationKinds.def
index 8788b8ff0ef0a4..9323d4e861a734 100644
--- a/clang/include/clang/AST/OperationKinds.def
+++ b/clang/include/clang/AST/OperationKinds.def
@@ -367,6 +367,9 @@ CAST_OPERATION(HLSLVectorTruncation)
 // Non-decaying array RValue cast (HLSL only).
 CAST_OPERATION(HLSLArrayRValue)
 
+// Aggregate by Value cast (HLSL only).
+CAST_OPERATION(HLSLAggregateCast)
+
 //===- Binary Operations  -------------------------------------------------===//
 // Operators listed in order of precedence.
 // Note that additions to this should also update the StmtVisitor class,
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index ee685d95c96154..6bda1e8ce0ea5b 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -140,6 +140,8 @@ class SemaHLSL : public SemaBase {
   // Diagnose whether the input ID is uint/unit2/uint3 type.
   bool diagnoseInputIDType(QualType T, const ParsedAttr &AL);
 
+  bool CanPerformScalarCast(QualType SrcTy, QualType DestTy);
+  bool CanPerformAggregateCast(Expr *Src, QualType DestType);
   ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
 
   QualType getInoutParameterType(QualType Ty);
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index a4fb4d5a1f2ec4..4764bc84ce498a 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -1942,6 +1942,7 @@ bool CastExpr::CastConsistency() const {
   case CK_FixedPointToBoolean:
   case CK_HLSLArrayRValue:
   case CK_HLSLVectorTruncation:
+  case CK_HLSLAggregateCast:
   CheckNoBasePath:
     assert(path_empty() && "Cast kind should not have a base path!");
     break;
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 6b5b95aee35522..b548cef41b7525 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -15733,6 +15733,7 @@ bool ComplexExprEvaluator::VisitCastExpr(const CastExpr *E) {
   case CK_IntegralToFixedPoint:
   case CK_MatrixCast:
   case CK_HLSLVectorTruncation:
+  case CK_HLSLAggregateCast:
     llvm_unreachable("invalid cast kind for complex value");
 
   case CK_LValueToRValue:
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 5fccc9cbb37ec1..6b9c437ef7e242 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -5320,6 +5320,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
   case CK_MatrixCast:
   case CK_HLSLVectorTruncation:
   case CK_HLSLArrayRValue:
+  case CK_HLSLAggregateCast:
     return EmitUnsupportedLValue(E, "unexpected cast lvalue");
 
   case CK_Dependent:
@@ -6358,3 +6359,88 @@ RValue CodeGenFunction::EmitPseudoObjectRValue(const PseudoObjectExpr *E,
 LValue CodeGenFunction::EmitPseudoObjectLValue(const PseudoObjectExpr *E) {
   return emitPseudoObjectExpr(*this, E, true, AggValueSlot::ignored()).LV;
 }
+
+llvm::Value *
+CodeGenFunction::PerformLoad(std::pair<Address, llvm::Value *> &GEP) {
+  Address GEPAddress = GEP.first;
+  llvm::Value *Idx = GEP.second;
+  llvm::Value *V = Builder.CreateLoad(GEPAddress, "load");
+  if (Idx) { // loading from a vector so perform an extract as well
+    return Builder.CreateExtractElement(V, Idx, "vec.load");
+  }
+  return V;
+}
+
+llvm::Value *
+CodeGenFunction::PerformStore(std::pair<Address, llvm::Value *> &GEP,
+                              llvm::Value *Val) {
+  Address GEPAddress = GEP.first;
+  llvm::Value *Idx = GEP.second;
+  if (Idx) {
+    llvm::Value *V = Builder.CreateLoad(GEPAddress, "load.for.insert");
+    return Builder.CreateInsertElement(V, Val, Idx);
+  } else {
+    return Builder.CreateStore(Val, GEPAddress);
+  }
+}
+
+void CodeGenFunction::FlattenAccessAndType(
+    Address Val, QualType SrcTy, SmallVector<llvm::Value *, 4> &IdxList,
+    SmallVector<std::pair<Address, llvm::Value *>, 16> &GEPList,
+    SmallVector<QualType> &FlatTypes) {
+  llvm::IntegerType *IdxTy = llvm::IntegerType::get(getLLVMContext(), 32);
+  if (const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(SrcTy)) {
+    uint64_t Size = CAT->getZExtSize();
+    for (unsigned i = 0; i < Size; i++) {
+      // flatten each member of the array
+      // add index of this element to index list
+      llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, i);
+      IdxList.push_back(Idx);
+      // recur on this object
+      FlattenAccessAndType(Val, CAT->getElementType(), IdxList, GEPList,
+                           FlatTypes);
+      // remove index of this element from index list
+      IdxList.pop_back();
+    }
+  } else if (const RecordType *RT = SrcTy->getAs<RecordType>()) {
+    RecordDecl *Record = RT->getDecl();
+    const CGRecordLayout &RL = getTypes().getCGRecordLayout(Record);
+    // do I need to check if its a cxx record decl?
+
+    for (auto fieldIter = Record->field_begin(), fieldEnd = Record->field_end();
+         fieldIter != fieldEnd; ++fieldIter) {
+      // get the field number
+      unsigned FieldNum = RL.getLLVMFieldNo(*fieldIter);
+      // can we just do *fieldIter->getFieldIndex();
+      // add that index to the index list
+      llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, FieldNum);
+      IdxList.push_back(Idx);
+      // recur on the field
+      FlattenAccessAndType(Val, fieldIter->getType(), IdxList, GEPList,
+                           FlatTypes);
+      // remove index of this element from index list
+      IdxList.pop_back();
+    }
+  } else if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
+    llvm::Type *VTy = ConvertTypeForMem(SrcTy);
+    CharUnits Align = getContext().getTypeAlignInChars(SrcTy);
+    Address GEP =
+        Builder.CreateInBoundsGEP(Val, IdxList, VTy, Align, "vector.gep");
+    for (unsigned i = 0; i < VT->getNumElements(); i++) {
+      // add index to the list
+      llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, i);
+      // create gep. no need to recur since its always a scalar
+      // gep on vector is not recommended so combine gep with extract/insert
+      GEPList.push_back({GEP, Idx});
+      FlatTypes.push_back(VT->getElementType());
+    }
+  } else { // should be a scalar should we assert or check?
+    // create a gep
+    llvm::Type *Ty = ConvertTypeForMem(SrcTy);
+    CharUnits Align = getContext().getTypeAlignInChars(SrcTy);
+    Address GEP = Builder.CreateInBoundsGEP(Val, IdxList, Ty, Align, "gep");
+    GEPList.push_back({GEP, NULL});
+    FlatTypes.push_back(SrcTy);
+  }
+  // target extension types?
+}
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 2ad6587089f101..e3b47de958ce55 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -491,6 +491,65 @@ static bool isTrivialFiller(Expr *E) {
   return false;
 }
 
+// emit a flat cast where the RHS is a scalar, including vector
+static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
+                                   QualType DestTy, llvm::Value *SrcVal,
+                                   QualType SrcTy, SourceLocation Loc) {
+  // Flatten our destination
+  SmallVector<QualType> DestTypes; // Flattened type
+  SmallVector<llvm::Value *, 4> IdxList;
+  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
+  // ^^ Flattened accesses to DestVal we want to store into
+  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes);
+
+  if (const VectorType *VT = SrcTy->getAs<VectorType>()) {
+    SrcTy = VT->getElementType();
+    assert(StoreGEPList.size() <= VT->getNumElements() &&
+           "Cannot perform HLSL flat cast when vector source \
+           object has less elements than flattened destination \
+           object.");
+    for (unsigned i = 0; i < StoreGEPList.size(); i++) {
+      llvm::Value *Load =
+          CGF.Builder.CreateExtractElement(SrcVal, i, "vec.load");
+      llvm::Value *Cast =
+          CGF.EmitScalarConversion(Load, SrcTy, DestTypes[i], Loc);
+      CGF.PerformStore(StoreGEPList[i], Cast);
+    }
+    return;
+  }
+  llvm_unreachable("HLSL Flat cast doesn't handle splatting.");
+}
+
+// emit a flat cast where the RHS is an aggregate
+static void EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address DestVal,
+                                      QualType DestTy, Address SrcVal,
+                                      QualType SrcTy, SourceLocation Loc) {
+  // Flatten our destination
+  SmallVector<QualType> DestTypes; // Flattened type
+  SmallVector<llvm::Value *, 4> IdxList;
+  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
+  // ^^ Flattened accesses to DestVal we want to store into
+  CGF.FlattenAccessAndType(DestVal, DestTy, IdxList, StoreGEPList, DestTypes);
+  // Flatten our src
+  SmallVector<QualType> SrcTypes; // Flattened type
+  SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
+  // ^^ Flattened accesses to SrcVal we want to load from
+  IdxList.clear();
+  CGF.FlattenAccessAndType(SrcVal, SrcTy, IdxList, LoadGEPList, SrcTypes);
+
+  assert(StoreGEPList.size() <= LoadGEPList.size() &&
+         "Cannot perform HLSL flat cast when flattened source object \
+          has less elements than flattened destination object.");
+  // apply casts to what we load from LoadGEPList
+  // and store result in Dest
+  for (unsigned i = 0; i < StoreGEPList.size(); i++) {
+    llvm::Value *Load = CGF.PerformLoad(LoadGEPList[i]);
+    llvm::Value *Cast =
+        CGF.EmitScalarConversion(Load, SrcTypes[i], DestTypes[i], Loc);
+    CGF.PerformStore(StoreGEPList[i], Cast);
+  }
+}
+
 /// Emit initialization of an array from an initializer list. ExprToVisit must
 /// be either an InitListEpxr a CXXParenInitListExpr.
 void AggExprEmitter::EmitArrayInit(Address DestPtr, llvm::ArrayType *AType,
@@ -890,7 +949,25 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
   case CK_HLSLArrayRValue:
     Visit(E->getSubExpr());
     break;
-
+  case CK_HLSLAggregateCast: {
+    Expr *Src = E->getSubExpr();
+    QualType SrcTy = Src->getType();
+    RValue RV = CGF.EmitAnyExpr(Src);
+    QualType DestTy = E->getType();
+    Address DestVal = Dest.getAddress();
+    SourceLocation Loc = E->getExprLoc();
+
+    if (RV.isScalar()) {
+      llvm::Value *SrcVal = RV.getScalarVal();
+      EmitHLSLScalarFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+    } else { // RHS is an aggregate
+      assert(RV.isAggregate() &&
+             "Can't perform HLSL Aggregate cast on a complex type.");
+      Address SrcVal = RV.getAggregateAddress();
+      EmitHLSLAggregateFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+    }
+    break;
+  }
   case CK_NoOp:
   case CK_UserDefinedConversion:
   case CK_ConstructorConversion:
diff --git a/clang/lib/CodeGen/CGExprComplex.cpp b/clang/lib/CodeGen/CGExprComplex.cpp
index ac31dff11b585e..05680d36aa2bd7 100644
--- a/clang/lib/CodeGen/CGExprComplex.cpp
+++ b/clang/lib/CodeGen/CGExprComplex.cpp
@@ -610,6 +610,7 @@ ComplexPairTy ComplexExprEmitter::EmitCast(CastKind CK, Expr *Op,
   case CK_MatrixCast:
   case CK_HLSLVectorTruncation:
   case CK_HLSLArrayRValue:
+  case CK_HLSLAggregateCast:
     llvm_unreachable("invalid cast kind for complex value");
 
   case CK_FloatingRealToComplex:
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index 655fc3dc954c81..6d15bc9058e450 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -1335,6 +1335,7 @@ class ConstExprEmitter
     case CK_MatrixCast:
     case CK_HLSLVectorTruncation:
     case CK_HLSLArrayRValue:
+    case CK_HLSLAggregateCast:
       return nullptr;
     }
     llvm_unreachable("Invalid CastKind");
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 4ae8a2b22b1bba..3809e3b1db3494 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2262,6 +2262,36 @@ bool CodeGenFunction::ShouldNullCheckClassCastValue(const CastExpr *CE) {
   return true;
 }
 
+// RHS is an aggregate type
+static Value *EmitHLSLAggregateFlatCast(CodeGenFunction &CGF, Address RHSVal,
+                                        QualType RHSTy, QualType LHSTy,
+                                        SourceLocation Loc) {
+  SmallVector<llvm::Value *, 4> IdxList;
+  SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
+  SmallVector<QualType> SrcTypes; // Flattened type
+  CGF.FlattenAccessAndType(RHSVal, RHSTy, IdxList, LoadGEPList, SrcTypes);
+  // LHS is either a vector or a builtin?
+  // if its a vector create a temp alloca to store into and return that
+  if (auto *VecTy = LHSTy->getAs<VectorType>()) {
+    llvm::Value *V =
+        CGF.Builder.CreateLoad(CGF.CreateIRTemp(LHSTy, "flatcast.tmp"));
+    // write to V.
+    for (unsigned i = 0; i < VecTy->getNumElements(); i++) {
+      llvm::Value *Load = CGF.PerformLoad(LoadGEPList[i]);
+      llvm::Value *Cast = CGF.EmitScalarConversion(
+          Load, SrcTypes[i], VecTy->getElementType(), Loc);
+      V = CGF.Builder.CreateInsertElement(V, Cast, i);
+    }
+    return V;
+  }
+  // i its a builtin just do an extract element or load.
+  assert(LHSTy->isBuiltinType() &&
+         "Destination type must be a vector or builtin type.");
+  // TODO add asserts about things being long enough
+  return CGF.EmitScalarConversion(CGF.PerformLoad(LoadGEPList[0]), LHSTy,
+                                  SrcTypes[0], Loc);
+}
+
 // VisitCastExpr - Emit code for an explicit or implicit cast.  Implicit casts
 // have to handle a more broad range of conversions than explicit casts, as they
 // handle things like function to ptr-to-function decay etc.
@@ -2752,7 +2782,17 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
     llvm::Value *Zero = llvm::Constant::getNullValue(CGF.SizeTy);
     return Builder.CreateExtractElement(Vec, Zero, "cast.vtrunc");
   }
+  case CK_HLSLAggregateCast: {
+    RValue RV = CGF.EmitAnyExpr(E);
+    SourceLocation Loc = CE->getExprLoc();
+    QualType SrcTy = E->getType();
 
+    if (RV.isAggregate()) { // RHS is an aggregate
+      Address SrcVal = RV.getAggregateAddress();
+      return EmitHLSLAggregateFlatCast(CGF, SrcVal, SrcTy, DestTy, Loc);
+    }
+    llvm_unreachable("Not a valid HLSL Flat Cast.");
+  }
   } // end of switch
 
   llvm_unreachable("unknown scalar cast");
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index eaea0d8a08ac06..873dd781eb2e7d 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4359,6 +4359,14 @@ class CodeGenFunction : public CodeGenTypeCache {
                                 AggValueSlot slot = AggValueSlot::ignored());
   LValue EmitPseudoObjectLValue(const PseudoObjectExpr *e);
 
+  llvm::Value *PerformLoad(std::pair<Address, llvm::Value *> &GEP);
+  llvm::Value *PerformStore(std::pair<Address, llvm::Value *> &GEP,
+                            llvm::Value *Val);
+  void FlattenAccessAndType(
+      Address Val, QualType SrcTy, SmallVector<llvm::Value *, 4> &IdxList,
+      SmallVector<std::pair<Address, llvm::Value *>, 16> &GEPList,
+      SmallVector<QualType> &FlatTypes);
+
   llvm::Value *EmitIvarOffset(const ObjCInterfaceDecl *Interface,
                               const ObjCIvarDecl *Ivar);
   llvm::Value *EmitIvarOffsetAsPointerDiff(const ObjCInterfaceDecl *Interface,
diff --git a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
index 81797c8c4dc75a..63308319a78d1c 100644
--- a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
+++ b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
@@ -1085,6 +1085,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
       llvm_unreachable("OpenCL-specific cast in Objective-C?");
 
     case CK_HLSLVectorTruncation:
+    case CK_HLSLAggregateCast:
       llvm_unreachable("HLSL-specific cast in Objective-C?");
       break;
 
diff --git a/clang/lib/Sema/Sema.cpp b/clang/lib/Sema/Sema.cpp
index d6517511d7db4d..2f0528d6ab5ce1 100644
--- a/clang/lib/Sema/Sema.cpp
+++ b/clang/lib/Sema/Sema.cpp
@@ -707,6 +707,7 @@ ExprResult Sema::ImpCastExprToType(Expr *E, QualType Ty,
     case CK_ToVoid:
     case CK_NonAtomicToAtomic:
     case CK_HLSLArrayRValue:
+    case CK_HLSLAggregateCast:
       break;
     }
   }
diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index f98857f852b5af..0bd7fc91aee18f 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -23,6 +23,7 @@
 #include "clang/Basic/TargetInfo.h"
 #include "clang/Lex/Preprocessor.h"
 #include "clang/Sema/Initialization.h"
+#include "clang/Sema/SemaHLSL.h"
 #include "clang/Sema/SemaObjC.h"
 #include "clang/Sema/SemaRISCV.h"
 #include "llvm/ADT/SmallVector.h"
@@ -2768,6 +2769,24 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
     return;
   }
 
+  CheckedConversionKind CCK = FunctionalStyle
+                                  ? CheckedConversionKind::FunctionalCast
+                                  : CheckedConversionKind::CStyleCast;
+  // todo what else should i be doing lvalue to rvalue cast for?
+  // why dont they do it for records below?
+  // This case should not trigger on regular vector splat
+  // Or vector cast or vector truncation.
+  QualType SrcTy = SrcExpr.get()->getType();
+  if (Self.getLangOpts().HLSL &&
+      Self.HLSL().CanPerformAggregateCast(SrcExpr.get(), DestType)) {
+    if (SrcTy->isConstantArrayType())
+      SrcExpr = Self.ImpCastExprToType(
+          SrcExpr.get(), Self.Context.getArrayParameterType(SrcTy),
+          CK_HLSLArrayRValue, VK_PRValue, nullptr, CCK);
+    Kind = CK_HLSLAggregateCast;
+    return;
+  }
+
   if (ValueKind == VK_PRValue && !DestType->isRecordType() &&
       !isPlaceholder(BuiltinType::Overload)) {
     SrcExpr = Self.DefaultFunctionArrayLvalueConversion(SrcExpr.get());
@@ -2820,9 +2839,6 @@ void CastOperation::CheckCXXCStyleCast(bool FunctionalStyle,
   if (isValidCast(tcr))
     Kind = CK_NoOp;
 
-  CheckedConversionKind CCK = FunctionalStyle
-                                  ? CheckedConversionKind::FunctionalCast
-                                  : CheckedConversionKind::CStyleCast;
   if (tcr == TC_NotApplicable) {
     tcr = TryAddressSpaceCast(Self, SrcExpr, DestType, /*CStyle*/ true, msg,
                               Kind);
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 88db3e12541193..5c7af8056063ad 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2412,6 +2412,150 @@ bool SemaHLSL::CheckCompatibleParameterABI(FunctionDecl *New,
   return HadError;
 }
 
+// Follows PerformScalarCast
+bool SemaHLSL::CanPerformScalarCast(QualType SrcTy, QualType DestTy) {
+
+  if (SemaRef.getASTContext().hasSameUnqualifiedType(SrcTy, DestTy))
+    return true;
+
+  switch (Type::ScalarTypeKind SrcKind = SrcTy->getScalarTypeKind()) {
+  case Type::STK_MemberPointer:
+    return false;
+
+  case Type::STK_CPointer:
+  case Type::STK_BlockPointer:
+  case Type::STK_ObjCObjectPointer:
+    switch (DestTy->getScalarTypeKind()) {
+    case Type::STK_CPointer:
+    case Type::STK_BlockPointer:
+    case Type::STK_ObjCObjectPointer:
+    case Type::STK_Bool:
+    case Type::STK_Integral:
+      return true;
+    case Type::STK_Floating:
+    case Type::STK_FloatingComplex:
+    case Type::STK_IntegralComplex:
+    case Type::STK_MemberPointer:
+      return false;
+    case Type::STK_FixedPoint:
+      llvm_unreachable("HLSL doesn't have fixed point types.");
+    }
+    llvm_unreachable("Should have returned before this");
+
+  case Type::STK_FixedPoint:
+    llvm_unreachable("HLSL doesn't have fixed point types.");
+
+  case Type::STK_Bool: // casting from bool is like casting from an integer
+  case Type::STK_Integral:
+    switch (DestTy->getScalarTypeKind()) {
+    case Type::STK_CPointer:
+    case Type::STK_ObjCObjectPointer:
+    case Type::STK_BlockPointer:
+    case Type::STK_Bool:
+    case Type::STK_Integral:
+    case Type::STK_Floating:
+    case Type::STK_IntegralComplex:
+    case Type::STK_FloatingComplex:
+      return true;
+    case Type::STK_FixedPoint:
+      llvm_unreachable("HLSL doesn't have fixed point types.");
+    case Type::STK_MemberPointer:
+      return false;
+    }
+    llvm_unreachable("Should have returned before this");
+
+  case Type::STK_Floating:
+    switch (DestTy->getScalarT...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/118842


More information about the cfe-commits mailing list