[clang] [HLSL] Add support for elementwise and aggregate splat casting struct types with bitfields (PR #161263)

Sarah Spall via cfe-commits cfe-commits at lists.llvm.org
Thu Oct 2 11:49:47 PDT 2025


https://github.com/spall updated https://github.com/llvm/llvm-project/pull/161263

From 34d95372ceaab6919f816a8f305abd2db698e818 Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Fri, 12 Sep 2025 13:44:19 -0700
Subject: [PATCH 1/2] Flatten LValues instead of addresses to reuse existing
 code and enable handling bitfields; update tests.

---
 clang/lib/CodeGen/CGExpr.cpp                  | 102 +++++++-----
 clang/lib/CodeGen/CGExprAgg.cpp               | 146 +++++++-----------
 clang/lib/CodeGen/CGExprScalar.cpp            |  48 +++---
 clang/lib/CodeGen/CodeGenFunction.h           |   6 +-
 clang/lib/Sema/SemaHLSL.cpp                   |   6 -
 .../BasicFeatures/AggregateSplatCast.hlsl     |  38 +++++
 .../BasicFeatures/ArrayElementwiseCast.hlsl   |  46 +++++-
 .../BasicFeatures/StructElementwiseCast.hlsl  | 121 ++++++++++++++-
 .../BasicFeatures/VectorElementwiseCast.hlsl  |  42 +++++
 .../Language/AggregateSplatCast-errors.hlsl   |   6 -
 .../Language/ElementwiseCast-errors.hlsl      |  21 ---
 11 files changed, 384 insertions(+), 198 deletions(-)
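
For context, a minimal HLSL sketch of the casts this patch enables (the
types mirror the BFields/Derived structs added to the tests below):

  struct BFields {
    double D;
    int E : 15;   // named bitfield: participates in the casts
    int : 8;      // unnamed bitfield: skipped when flattening
    float F;
  };
  struct Derived : BFields { int G; };

  export void demo(int A, int4 V, Derived Src) {
    Derived S0 = (Derived)A;  // aggregate splat cast into a bitfield struct
    Derived S1 = (Derived)V;  // elementwise cast from a vector
    int4 W = (int4)Src;       // elementwise cast reading bitfields out
  }

Before this change, Sema rejected all three because the destination or
source type contains bitfields.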

diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index e6e4947882544..04e2b64d2bd7c 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -6784,29 +6784,26 @@ LValue CodeGenFunction::EmitPseudoObjectLValue(const PseudoObjectExpr *E) {
   return emitPseudoObjectExpr(*this, E, true, AggValueSlot::ignored()).LV;
 }
 
-void CodeGenFunction::FlattenAccessAndType(
-    Address Addr, QualType AddrType,
-    SmallVectorImpl<std::pair<Address, llvm::Value *>> &AccessList,
-    SmallVectorImpl<QualType> &FlatTypes) {
-  // WorkList is list of type we are processing + the Index List to access
-  // the field of that type in Addr for use in a GEP
-  llvm::SmallVector<std::pair<QualType, llvm::SmallVector<llvm::Value *, 4>>,
-                    16>
+void CodeGenFunction::FlattenAccessAndTypeLValue(
+    LValue Val, SmallVectorImpl<LValue> &AccessList) {
+
+  llvm::SmallVector<
+      std::tuple<LValue, QualType, llvm::SmallVector<llvm::Value *, 4>>, 16>
       WorkList;
   llvm::IntegerType *IdxTy = llvm::IntegerType::get(getLLVMContext(), 32);
-  // Addr should be a pointer so we need to 'dereference' it
-  WorkList.push_back({AddrType, {llvm::ConstantInt::get(IdxTy, 0)}});
+  WorkList.push_back({Val, Val.getType(), {llvm::ConstantInt::get(IdxTy, 0)}});
 
   while (!WorkList.empty()) {
-    auto [T, IdxList] = WorkList.pop_back_val();
+    auto [LVal, T, IdxList] = WorkList.pop_back_val();
     T = T.getCanonicalType().getUnqualifiedType();
     assert(!isa<MatrixType>(T) && "Matrix types not yet supported in HLSL");
+
     if (const auto *CAT = dyn_cast<ConstantArrayType>(T)) {
       uint64_t Size = CAT->getZExtSize();
       for (int64_t I = Size - 1; I > -1; I--) {
         llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
         IdxListCopy.push_back(llvm::ConstantInt::get(IdxTy, I));
-        WorkList.emplace_back(CAT->getElementType(), IdxListCopy);
+        WorkList.push_back({LVal, CAT->getElementType(), IdxListCopy});
       }
     } else if (const auto *RT = dyn_cast<RecordType>(T)) {
       const RecordDecl *Record = RT->getOriginalDecl()->getDefinitionOrSelf();
@@ -6814,44 +6811,77 @@ void CodeGenFunction::FlattenAccessAndType(
 
       const CXXRecordDecl *CXXD = dyn_cast<CXXRecordDecl>(Record);
 
-      llvm::SmallVector<QualType, 16> FieldTypes;
+      llvm::SmallVector<
+          std::tuple<LValue, QualType, llvm::SmallVector<llvm::Value *, 4>>, 16>
+          ReverseList;
       if (CXXD && CXXD->isStandardLayout())
         Record = CXXD->getStandardLayoutBaseWithFields();
 
       // deal with potential base classes
       if (CXXD && !CXXD->isStandardLayout()) {
-        for (auto &Base : CXXD->bases())
-          FieldTypes.push_back(Base.getType());
+        if (CXXD->getNumBases() > 0) {
+          assert(CXXD->getNumBases() == 1 &&
+                 "HLSL doesn't support multiple inheritance.");
+          auto Base = CXXD->bases_begin();
+          llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
+          IdxListCopy.push_back(llvm::ConstantInt::get(
+              IdxTy, 0)); // base struct should be at index zero
+          ReverseList.insert(ReverseList.end(),
+                             {LVal, Base->getType(), IdxListCopy});
+        }
       }
 
-      for (auto *FD : Record->fields())
-        FieldTypes.push_back(FD->getType());
+      const CGRecordLayout &Layout = CGM.getTypes().getCGRecordLayout(Record);
 
-      for (int64_t I = FieldTypes.size() - 1; I > -1; I--) {
-        llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
-        IdxListCopy.push_back(llvm::ConstantInt::get(IdxTy, I));
-        WorkList.insert(WorkList.end(), {FieldTypes[I], IdxListCopy});
+      llvm::Type *LLVMT = ConvertTypeForMem(T);
+      CharUnits Align = getContext().getTypeAlignInChars(T);
+      LValue RLValue;
+      bool CreatedGEP = false;
+      for (auto *FD : Record->fields()) {
+        if (FD->isBitField()) {
+          if (FD->isUnnamedBitField())
+            continue;
+          if (!CreatedGEP) {
+            CreatedGEP = true;
+            Address GEP = Builder.CreateInBoundsGEP(LVal.getAddress(), IdxList,
+                                                    LLVMT, Align, "gep");
+            RLValue = MakeAddrLValue(GEP, T);
+          }
+          LValue FieldLVal = EmitLValueForField(RLValue, FD, true);
+          ReverseList.insert(ReverseList.end(), {FieldLVal, FD->getType(), {}});
+        } else {
+          llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
+          IdxListCopy.push_back(
+              llvm::ConstantInt::get(IdxTy, Layout.getLLVMFieldNo(FD)));
+          ReverseList.insert(ReverseList.end(),
+                             {LVal, FD->getType(), IdxListCopy});
+        }
       }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
     } else if (const auto *VT = dyn_cast<VectorType>(T)) {
       llvm::Type *LLVMT = ConvertTypeForMem(T);
       CharUnits Align = getContext().getTypeAlignInChars(T);
-      Address GEP =
-          Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "vector.gep");
+      Address GEP = Builder.CreateInBoundsGEP(LVal.getAddress(), IdxList, LLVMT,
+                                              Align, "vector.gep");
+      LValue Base = MakeAddrLValue(GEP, T);
       for (unsigned I = 0, E = VT->getNumElements(); I < E; I++) {
-        llvm::Value *Idx = llvm::ConstantInt::get(IdxTy, I);
-        // gep on vector fields is not recommended so combine gep with
-        // extract/insert
-        AccessList.emplace_back(GEP, Idx);
-        FlatTypes.push_back(VT->getElementType());
+        llvm::Constant *Idx = llvm::ConstantInt::get(IdxTy, I);
+        LValue LV =
+            LValue::MakeVectorElt(Base.getAddress(), Idx, VT->getElementType(),
+                                  Base.getBaseInfo(), TBAAAccessInfo());
+        AccessList.emplace_back(LV);
       }
-    } else {
-      // a scalar/builtin type
-      llvm::Type *LLVMT = ConvertTypeForMem(T);
-      CharUnits Align = getContext().getTypeAlignInChars(T);
-      Address GEP =
-          Builder.CreateInBoundsGEP(Addr, IdxList, LLVMT, Align, "gep");
-      AccessList.emplace_back(GEP, nullptr);
-      FlatTypes.push_back(T);
+    } else { // a scalar/builtin type
+      if (!IdxList.empty()) {
+        llvm::Type *LLVMT = ConvertTypeForMem(T);
+        CharUnits Align = getContext().getTypeAlignInChars(T);
+        Address GEP = Builder.CreateInBoundsGEP(LVal.getAddress(), IdxList,
+                                                LLVMT, Align, "gep");
+        AccessList.emplace_back(MakeAddrLValue(GEP, T));
+      } else // must be a bitfield; we already created an LValue for it
+        AccessList.emplace_back(LVal);
     }
   }
 }
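
As a reference for the ordering produced above (a hedged sketch; the field
names are illustrative): the WorkList is LIFO, so each record's members are
first collected in declaration order into ReverseList and then reversed
before being appended, which makes them pop back out in declaration order,
with the base subobject first and unnamed bitfields dropped:

  struct BFields {
    double D;
    int E : 15;
    int : 8;      // unnamed bitfield: no LValue is produced for it
    float F;
  };
  struct Derived : BFields { int G; };
  // FlattenAccessAndTypeLValue on a Derived lvalue appends, in order:
  //   D (GEP-based LValue), E (bitfield LValue via EmitLValueForField),
  //   F (GEP-based LValue), G (GEP-based LValue)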
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index b8150a24d45fc..07b9aebe0bbe3 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -488,100 +488,62 @@ static bool isTrivialFiller(Expr *E) {
   return false;
 }
 
-static void EmitHLSLAggregateSplatCast(CodeGenFunction &CGF, Address DestVal,
-                                       QualType DestTy, llvm::Value *SrcVal,
-                                       QualType SrcTy, SourceLocation Loc) {
+// emit an elementwise cast where the RHS is a scalar or vector
+// or emit an aggregate splat cast
+static void EmitHLSLScalarElementwiseAndSplatCasts(CodeGenFunction &CGF,
+                                                   LValue DestVal,
+                                                   llvm::Value *SrcVal,
+                                                   QualType SrcTy,
+                                                   SourceLocation Loc) {
   // Flatten our destination
-  SmallVector<QualType> DestTypes; // Flattened type
-  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
-  // ^^ Flattened accesses to DestVal we want to store into
-  CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
-
-  assert(SrcTy->isScalarType() && "Invalid HLSL Aggregate splat cast.");
-  for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; ++I) {
-    llvm::Value *Cast =
-        CGF.EmitScalarConversion(SrcVal, SrcTy, DestTypes[I], Loc);
-
-    // store back
-    llvm::Value *Idx = StoreGEPList[I].second;
-    if (Idx) {
-      llvm::Value *V =
-          CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
-      Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
-    }
-    CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
-  }
-}
-
-// emit a flat cast where the RHS is a scalar, including vector
-static void EmitHLSLScalarFlatCast(CodeGenFunction &CGF, Address DestVal,
-                                   QualType DestTy, llvm::Value *SrcVal,
-                                   QualType SrcTy, SourceLocation Loc) {
-  // Flatten our destination
-  SmallVector<QualType, 16> DestTypes; // Flattened type
-  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
-  // ^^ Flattened accesses to DestVal we want to store into
-  CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
-
-  assert(SrcTy->isVectorType() && "HLSL Flat cast doesn't handle splatting.");
-  const VectorType *VT = SrcTy->getAs<VectorType>();
-  SrcTy = VT->getElementType();
-  assert(StoreGEPList.size() <= VT->getNumElements() &&
-         "Cannot perform HLSL flat cast when vector source \
-         object has less elements than flattened destination \
-         object.");
-  for (unsigned I = 0, Size = StoreGEPList.size(); I < Size; I++) {
-    llvm::Value *Load = CGF.Builder.CreateExtractElement(SrcVal, I, "vec.load");
+  SmallVector<LValue, 16> StoreList;
+  CGF.FlattenAccessAndTypeLValue(DestVal, StoreList);
+
+  bool IsVector = false;
+  if (auto *VT = SrcTy->getAs<VectorType>()) {
+    IsVector = true;
+    SrcTy = VT->getElementType();
+    assert(StoreList.size() <= VT->getNumElements() &&
+           "Cannot perform HLSL flat cast when vector source \
+           object has fewer elements than flattened destination \
+           object.");
+  }
+
+  for (unsigned I = 0, Size = StoreList.size(); I < Size; I++) {
+    LValue DestLVal = StoreList[I];
+    llvm::Value *Load =
+        IsVector ? CGF.Builder.CreateExtractElement(SrcVal, I, "vec.load")
+                 : SrcVal;
     llvm::Value *Cast =
-        CGF.EmitScalarConversion(Load, SrcTy, DestTypes[I], Loc);
-
-    // store back
-    llvm::Value *Idx = StoreGEPList[I].second;
-    if (Idx) {
-      llvm::Value *V =
-          CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
-      Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
-    }
-    CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
+        CGF.EmitScalarConversion(Load, SrcTy, DestLVal.getType(), Loc);
+    CGF.EmitStoreThroughLValue(RValue::get(Cast), DestLVal);
   }
 }
 
 // emit a flat cast where the RHS is an aggregate
-static void EmitHLSLElementwiseCast(CodeGenFunction &CGF, Address DestVal,
-                                    QualType DestTy, Address SrcVal,
-                                    QualType SrcTy, SourceLocation Loc) {
+static void EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue DestVal,
+                                    LValue SrcVal, SourceLocation Loc) {
   // Flatten our destination
-  SmallVector<QualType, 16> DestTypes; // Flattened type
-  SmallVector<std::pair<Address, llvm::Value *>, 16> StoreGEPList;
-  // ^^ Flattened accesses to DestVal we want to store into
-  CGF.FlattenAccessAndType(DestVal, DestTy, StoreGEPList, DestTypes);
+  SmallVector<LValue, 16> StoreList;
+  CGF.FlattenAccessAndTypeLValue(DestVal, StoreList);
   // Flatten our src
-  SmallVector<QualType, 16> SrcTypes; // Flattened type
-  SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
-  // ^^ Flattened accesses to SrcVal we want to load from
-  CGF.FlattenAccessAndType(SrcVal, SrcTy, LoadGEPList, SrcTypes);
+  SmallVector<LValue, 16> LoadList;
+  CGF.FlattenAccessAndTypeLValue(SrcVal, LoadList);
 
-  assert(StoreGEPList.size() <= LoadGEPList.size() &&
-         "Cannot perform HLSL flat cast when flattened source object \
+  assert(StoreList.size() <= LoadList.size() &&
+         "Cannot perform HLSL elementwise cast when flattened source object \
           has less elements than flattened destination object.");
-  // apply casts to what we load from LoadGEPList
+  // apply casts to what we load from LoadList
   // and store result in Dest
-  for (unsigned I = 0, E = StoreGEPList.size(); I < E; I++) {
-    llvm::Value *Idx = LoadGEPList[I].second;
-    llvm::Value *Load = CGF.Builder.CreateLoad(LoadGEPList[I].first, "load");
-    Load =
-        Idx ? CGF.Builder.CreateExtractElement(Load, Idx, "vec.extract") : Load;
-    llvm::Value *Cast =
-        CGF.EmitScalarConversion(Load, SrcTypes[I], DestTypes[I], Loc);
-
-    // store back
-    Idx = StoreGEPList[I].second;
-    if (Idx) {
-      llvm::Value *V =
-          CGF.Builder.CreateLoad(StoreGEPList[I].first, "load.for.insert");
-      Cast = CGF.Builder.CreateInsertElement(V, Cast, Idx);
-    }
-    CGF.Builder.CreateStore(Cast, StoreGEPList[I].first);
+  for (unsigned I = 0, E = StoreList.size(); I < E; I++) {
+    LValue DestLVal = StoreList[I];
+    LValue SrcLVal = LoadList[I];
+    RValue RVal = CGF.EmitLoadOfLValue(SrcLVal, Loc);
+    assert(RVal.isScalar() && "All flattened source values should be scalars");
+    llvm::Value *Val = RVal.getScalarVal();
+    llvm::Value *Cast = CGF.EmitScalarConversion(Val, SrcLVal.getType(),
+                                                 DestLVal.getType(), Loc);
+    CGF.EmitStoreThroughLValue(RValue::get(Cast), DestLVal);
   }
 }
 
@@ -988,31 +950,33 @@ void AggExprEmitter::VisitCastExpr(CastExpr *E) {
     Expr *Src = E->getSubExpr();
     QualType SrcTy = Src->getType();
     RValue RV = CGF.EmitAnyExpr(Src);
-    QualType DestTy = E->getType();
-    Address DestVal = Dest.getAddress();
+    LValue DestLVal = CGF.MakeAddrLValue(Dest.getAddress(), E->getType());
     SourceLocation Loc = E->getExprLoc();
 
-    assert(RV.isScalar() && "RHS of HLSL splat cast must be a scalar.");
+    assert(RV.isScalar() && SrcTy->isScalarType() &&
+           "RHS of HLSL splat cast must be a scalar.");
     llvm::Value *SrcVal = RV.getScalarVal();
-    EmitHLSLAggregateSplatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+    EmitHLSLScalarElementwiseAndSplatCasts(CGF, DestLVal, SrcVal, SrcTy, Loc);
     break;
   }
   case CK_HLSLElementwiseCast: {
     Expr *Src = E->getSubExpr();
     QualType SrcTy = Src->getType();
     RValue RV = CGF.EmitAnyExpr(Src);
-    QualType DestTy = E->getType();
-    Address DestVal = Dest.getAddress();
+    LValue DestLVal = CGF.MakeAddrLValue(Dest.getAddress(), E->getType());
     SourceLocation Loc = E->getExprLoc();
 
     if (RV.isScalar()) {
       llvm::Value *SrcVal = RV.getScalarVal();
-      EmitHLSLScalarFlatCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+      assert(SrcTy->isVectorType() &&
+             "HLSL Elementwise cast doesn't handle splatting.");
+      EmitHLSLScalarElementwiseAndSplatCasts(CGF, DestLVal, SrcVal, SrcTy, Loc);
     } else {
       assert(RV.isAggregate() &&
              "Can't perform HLSL Aggregate cast on a complex type.");
       Address SrcVal = RV.getAggregateAddress();
-      EmitHLSLElementwiseCast(CGF, DestVal, DestTy, SrcVal, SrcTy, Loc);
+      EmitHLSLElementwiseCast(CGF, DestLVal, CGF.MakeAddrLValue(SrcVal, SrcTy),
+                              Loc);
     }
     break;
   }
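
To summarize the dispatch above with a hedged HLSL example (struct S here
is illustrative): a splat cast takes a scalar source and converts it once
per flattened destination element, while an elementwise cast whose RValue
is a scalar must have a vector source whose lanes are distributed across
the flattened destination:

  struct S { int X; float Y; };
  export void demo(int A, int2 V) {
    S s0 = (S)A;  // CK_HLSLAggregateSplatCast: A -> X and, via sitofp, -> Y
    S s1 = (S)V;  // CK_HLSLElementwiseCast: V[0] -> X, V[1] -> Y
  }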
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 4fa25c5d66669..00ebf41b315bb 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2392,39 +2392,36 @@ bool CodeGenFunction::ShouldNullCheckClassCastValue(const CastExpr *CE) {
 }
 
 // RHS is an aggregate type
-static Value *EmitHLSLElementwiseCast(CodeGenFunction &CGF, Address RHSVal,
-                                      QualType RHSTy, QualType LHSTy,
-                                      SourceLocation Loc) {
-  SmallVector<std::pair<Address, llvm::Value *>, 16> LoadGEPList;
-  SmallVector<QualType, 16> SrcTypes; // Flattened type
-  CGF.FlattenAccessAndType(RHSVal, RHSTy, LoadGEPList, SrcTypes);
-  // LHS is either a vector or a builtin?
+static Value *EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue SrcVal,
+                                      QualType DestTy, SourceLocation Loc) {
+  SmallVector<LValue, 16> LoadList;
+  CGF.FlattenAccessAndTypeLValue(SrcVal, LoadList);
+  // Dest is either a vector or a builtin type.
   // if its a vector create a temp alloca to store into and return that
-  if (auto *VecTy = LHSTy->getAs<VectorType>()) {
-    assert(SrcTypes.size() >= VecTy->getNumElements() &&
+  if (auto *VecTy = DestTy->getAs<VectorType>()) {
+    assert(LoadList.size() >= VecTy->getNumElements() &&
            "Flattened type on RHS must have more elements than vector on LHS.");
     llvm::Value *V =
-        CGF.Builder.CreateLoad(CGF.CreateIRTemp(LHSTy, "flatcast.tmp"));
+        CGF.Builder.CreateLoad(CGF.CreateIRTemp(DestTy, "flatcast.tmp"));
     // write to V.
     for (unsigned I = 0, E = VecTy->getNumElements(); I < E; I++) {
-      llvm::Value *Load = CGF.Builder.CreateLoad(LoadGEPList[I].first, "load");
-      llvm::Value *Idx = LoadGEPList[I].second;
-      Load = Idx ? CGF.Builder.CreateExtractElement(Load, Idx, "vec.extract")
-                 : Load;
-      llvm::Value *Cast = CGF.EmitScalarConversion(
-          Load, SrcTypes[I], VecTy->getElementType(), Loc);
+      RValue RVal = CGF.EmitLoadOfLValue(LoadList[I], Loc);
+      assert(RVal.isScalar() &&
+             "All flattened source values should be scalars.");
+      llvm::Value *Cast =
+          CGF.EmitScalarConversion(RVal.getScalarVal(), LoadList[I].getType(),
+                                   VecTy->getElementType(), Loc);
       V = CGF.Builder.CreateInsertElement(V, Cast, I);
     }
     return V;
   }
-  // i its a builtin just do an extract element or load.
-  assert(LHSTy->isBuiltinType() &&
+  // if it's a builtin type, just do an extract element or load.
+  assert(DestTy->isBuiltinType() &&
          "Destination type must be a vector or builtin type.");
-  llvm::Value *Load = CGF.Builder.CreateLoad(LoadGEPList[0].first, "load");
-  llvm::Value *Idx = LoadGEPList[0].second;
-  Load =
-      Idx ? CGF.Builder.CreateExtractElement(Load, Idx, "vec.extract") : Load;
-  return CGF.EmitScalarConversion(Load, LHSTy, SrcTypes[0], Loc);
+  RValue RVal = CGF.EmitLoadOfLValue(LoadList[0], Loc);
+  assert(RVal.isScalar() && "All flattened source values should be scalars.");
+  return CGF.EmitScalarConversion(RVal.getScalarVal(), LoadList[0].getType(),
+                                  DestTy, Loc);
 }
 
 // VisitCastExpr - Emit code for an explicit or implicit cast.  Implicit casts
@@ -2949,12 +2946,11 @@ Value *ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
   case CK_HLSLElementwiseCast: {
     RValue RV = CGF.EmitAnyExpr(E);
     SourceLocation Loc = CE->getExprLoc();
-    QualType SrcTy = E->getType();
 
     assert(RV.isAggregate() && "Not a valid HLSL Elementwise Cast.");
     // RHS is an aggregate
-    Address SrcVal = RV.getAggregateAddress();
-    return EmitHLSLElementwiseCast(CGF, SrcVal, SrcTy, DestTy, Loc);
+    LValue SrcVal = CGF.MakeAddrLValue(RV.getAggregateAddress(), E->getType());
+    return EmitHLSLElementwiseCast(CGF, SrcVal, DestTy, Loc);
   }
   } // end of switch
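
The scalar path above handles the opposite direction, where the destination
of the elementwise cast is a vector or a builtin type; a hedged sketch
(struct S is illustrative):

  struct S { int X; float Y; };
  export void demo(S s) {
    int2 V = (int2)s;  // vector dest: one insertelement per converted value
    int B = (int)s;    // builtin dest: only LoadList[0] (s.X) is converted
  }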
 
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 727487b46054f..a36aaca38b0e9 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4463,10 +4463,8 @@ class CodeGenFunction : public CodeGenTypeCache {
                                 AggValueSlot slot = AggValueSlot::ignored());
   LValue EmitPseudoObjectLValue(const PseudoObjectExpr *e);
 
-  void FlattenAccessAndType(
-      Address Addr, QualType AddrTy,
-      SmallVectorImpl<std::pair<Address, llvm::Value *>> &AccessList,
-      SmallVectorImpl<QualType> &FlatTypes);
+  void FlattenAccessAndTypeLValue(LValue LVal,
+                                  SmallVectorImpl<LValue> &AccessList);
 
   llvm::Value *EmitIvarOffset(const ObjCInterfaceDecl *Interface,
                               const ObjCIvarDecl *Ivar);
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index e17f9b9e8d758..3a48c9fcc7e5b 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3557,9 +3557,6 @@ bool SemaHLSL::CanPerformAggregateSplatCast(Expr *Src, QualType DestTy) {
   if (SrcVecTy)
     SrcTy = SrcVecTy->getElementType();
 
-  if (ContainsBitField(DestTy))
-    return false;
-
   llvm::SmallVector<QualType> DestTypes;
   BuildFlattenedTypeList(DestTy, DestTypes);
 
@@ -3586,9 +3583,6 @@ bool SemaHLSL::CanPerformElementwiseCast(Expr *Src, QualType DestTy) {
       (DestTy->isScalarType() || DestTy->isVectorType()))
     return false;
 
-  if (ContainsBitField(DestTy) || ContainsBitField(SrcTy))
-    return false;
-
   llvm::SmallVector<QualType> DestTypes;
   BuildFlattenedTypeList(DestTy, DestTypes);
   llvm::SmallVector<QualType> SrcTypes;
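
With the ContainsBitField checks gone, casts that previously failed with
"no matching conversion" diagnostics (see the Sema tests deleted further
below) are now accepted; a minimal sketch:

  struct S {
    int A : 8;
    int B;
  };
  export void nowAllowed() {
    S s = (S)1;        // aggregate splat cast into a struct with a bitfield
    int2 C = (int2)s;  // elementwise cast reading the bitfield
  }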
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
index 512fcd435191a..63b2b50dd9885 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
@@ -85,3 +85,41 @@ export void call5() {
   int1 A = {1};
   S s = (S)A;
 }
+
+struct BFields {
+  double D;
+  int E: 15;
+  int : 8;
+  float F;
+};
+
+struct Derived : BFields {
+  int G;
+};
+
+// derived struct with bitfields splat from scalar
+// CHECK-LABEL: call6
+// CHECK: [[AAddr:%.*]] = alloca i32, align 4
+// CHECK-NEXT: [[D:%.*]] = alloca %struct.Derived, align 1
+// CHECK-NEXT: store i32 %A, ptr [[AAddr]], align 4
+// CHECK-NEXT: [[B:%.*]] = load i32, ptr [[AAddr]], align 4
+// CHECK-NEXT: [[Gep:%.*]] = getelementptr inbounds %struct.Derived, ptr [[D]], i32 0, i32 0
+// CHECK-NEXT: [[E:%.*]] = getelementptr inbounds nuw %struct.BFields, ptr [[Gep]], i32 0, i32 1
+// CHECK-NEXT: [[Gep1:%.*]] = getelementptr inbounds %struct.Derived, ptr [[D]], i32 0, i32 0, i32 0
+// CHECK-NEXT: [[Gep2:%.*]] = getelementptr inbounds %struct.Derived, ptr [[D]], i32 0, i32 0, i32 2
+// CHECK-NEXT: [[Gep3:%.*]] = getelementptr inbounds %struct.Derived, ptr [[D]], i32 0, i32 1
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[B]] to double
+// CHECK-NEXT: store double [[C]], ptr [[Gep1]], align 8
+// CHECK-NEXT: [[D:%.*]] = trunc i32 [[B]] to i24
+// CHECK-NEXT: [[BFL:%.*]] = load i24, ptr [[E]], align 1
+// CHECK-NEXT: [[BFV:%.*]] = and i24 [[D]], 32767
+// CHECK-NEXT: [[BFC:%.*]] = and i24 [[BFL]], -32768
+// CHECK-NEXT: [[BFS:%.*]] = or i24 [[BFC]], [[BFV]]
+// CHECK-NEXT: store i24 [[BFS]], ptr [[E]], align 1
+// CHECK-NEXT: [[C4:%.*]] = sitofp i32 [[B]] to float
+// CHECK-NEXT: store float [[C4]], ptr [[Gep2]], align 4
+// CHECK-NEXT: store i32 [[B]], ptr [[Gep3]], align 4
+// CHECK-NEXT: ret void
+export void call6(int A) {
+  Derived D = (Derived)A;
+}
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/ArrayElementwiseCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/ArrayElementwiseCast.hlsl
index ac02ddf5765ed..5f2182e27285e 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/ArrayElementwiseCast.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/ArrayElementwiseCast.hlsl
@@ -10,7 +10,8 @@
 // CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds [2 x i32], ptr [[Tmp]], i32 0, i32 0
 // CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds [2 x i32], ptr [[Tmp]], i32 0, i32 1
 // CHECK-NEXT: [[L:%.*]] = load i32, ptr [[G1]], align 4
-// CHECK-NEXT: store i32 [[L]], ptr [[B]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[L]] to float
+// CHECK-NEXT: store float [[C]], ptr [[B]], align 4
 export void call0() {
   int A[2] = {0,1};
   float B = (float)A;
@@ -141,3 +142,46 @@ export void call7() {
   int A[1] = {1};
   A = (int[1])s;
 }
+
+struct BFields {
+  double D;
+  int E: 15;
+  int : 8;
+  float F;
+};
+
+struct Derived : BFields {
+  int G;
+};
+
+// flatten from a derived struct with bitfields
+// CHECK-LABEL: call8
+// CHECK: [[A:%.*]] = alloca [4 x i32], align 4
+// CHECK-NEXT: [[Tmp:%.*]] = alloca %struct.Derived, align 1
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 1 [[Tmp]], ptr align 1 %D, i32 19, i1 false)
+// CHECK-NEXT: [[Gep:%.*]] = getelementptr inbounds [4 x i32], ptr [[A]], i32 0, i32 0
+// CHECK-NEXT: [[Gep1:%.*]] = getelementptr inbounds [4 x i32], ptr [[A]], i32 0, i32 1
+// CHECK-NEXT: [[Gep2:%.*]] = getelementptr inbounds [4 x i32], ptr [[A]], i32 0, i32 2
+// CHECK-NEXT: [[Gep3:%.*]] = getelementptr inbounds [4 x i32], ptr [[A]], i32 0, i32 3
+// CHECK-NEXT: [[Gep4:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 0
+// CHECK-NEXT: [[E:%.*]] = getelementptr inbounds nuw %struct.BFields, ptr [[Gep4]], i32 0, i32 1
+// CHECK-NEXT: [[Gep5:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 0, i32 0
+// CHECK-NEXT: [[Gep6:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 0, i32 2
+// CHECK-NEXT: [[Gep7:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 1
+// CHECK-NEXT: [[Z:%.*]] = load double, ptr [[Gep5]], align 8
+// CHECK-NEXT: [[C:%.*]] = fptosi double [[Z]] to i32
+// CHECK-NEXT: store i32 [[C]], ptr [[Gep]], align 4
+// CHECK-NEXT: [[BFL:%.*]] = load i24, ptr [[E]], align 1
+// CHECK-NEXT: [[BFShl:%.*]] = shl i24 [[BFL]], 9
+// CHECK-NEXT: [[BFAshr:%.*]] = ashr i24 [[BFShl]], 9
+// CHECK-NEXT: [[BFC:%.*]] = sext i24 [[BFAshr]] to i32
+// CHECK-NEXT: store i32 [[BFC]], ptr [[Gep1]], align 4
+// CHECK-NEXT: [[Y:%.*]] = load float, ptr [[Gep6]], align 4
+// CHECK-NEXT: [[C8:%.*]] = fptosi float [[Y]] to i32
+// CHECK-NEXT: store i32 [[C8]], ptr [[Gep2]], align 4
+// CHECK-NEXT: [[X:%.*]] = load i32, ptr [[Gep7]], align 4
+// CHECK-NEXT: store i32 [[X]], ptr [[Gep3]], align 4
+// CHECK-NEXT: ret void
+export void call8(Derived D) {
+  int A[4] = (int[4])D;
+}
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/StructElementwiseCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/StructElementwiseCast.hlsl
index 81b9f5b28cc7e..39928e87c4df9 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/StructElementwiseCast.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/StructElementwiseCast.hlsl
@@ -127,14 +127,121 @@ struct T {
 // CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
 // CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
 // CHECK-NEXT: [[G3:%.*]] = getelementptr inbounds %struct.T, ptr [[Tmp]], i32 0, i32 0
-// CHECK-NEXT: %gep3 = getelementptr inbounds %struct.T, ptr %agg-temp, i32 0, i32 1
-// CHECK-NEXT: %gep4 = getelementptr inbounds %struct.T, ptr %agg-temp, i32 0, i32 2
-// CHECK-NEXT: %load = load i32, ptr %gep2, align 4
-// CHECK-NEXT: store i32 %load, ptr %gep, align 4
-// CHECK-NEXT: %load5 = load i32, ptr %gep3, align 4
-// CHECK-NEXT: %conv = sitofp i32 %load5 to float
-// CHECK-NEXT: store float %conv, ptr %gep1, align 4
+// CHECK-NEXT: [[G4:%.*]] = getelementptr inbounds %struct.T, ptr %agg-temp, i32 0, i32 1
+// CHECK-NEXT: [[G5:%.*]] = getelementptr inbounds %struct.T, ptr %agg-temp, i32 0, i32 2
+// CHECK-NEXT: [[L1:%.*]] = load i32, ptr [[G3]], align 4
+// CHECK-NEXT: store i32 [[L1]], ptr [[G1]], align 4
+// CHECK-NEXT: [[L2:%.*]] = load i32, ptr [[G4]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[L2]] to float
+// CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
 export void call8() {
   T t = {1,2,3};
   S s = (S)t;
 }
+
+struct BFields {
+  double D;
+  int E: 15;
+  int : 8;
+  float F;
+};
+
+struct Derived : BFields {
+  int G;
+};
+
+// Derived struct truncated to scalar
+// CHECK-LABEL: call9
+// CHECK: [[D2:%.*]] = alloca double, align 8
+// CHECK-NEXT: [[Tmp:%.*]] = alloca %struct.Derived, align 1
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 1 [[Tmp]], ptr align 1 %D, i32 19, i1 false)
+// CHECK-NEXT: [[Gep:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 0
+// CHECK-NEXT: [[E:%.*]] = getelementptr inbounds nuw %struct.BFields, ptr [[Gep]], i32 0, i32 1
+// CHECK-NEXT: [[Gep1:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 0, i32 0
+// CHECK-NEXT: [[Gep2:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 0, i32 2
+// CHECK-NEXT: [[Gep3:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 1
+// CHECK-NEXT: [[A:%.*]] = load double, ptr [[Gep1]], align 8
+// CHECK-NEXT: store double [[A]], ptr [[D2]], align 8
+// CHECK-NEXT: ret void
+export void call9(Derived D) {
+  double D2 = (double)D;
+}
+
+// Derived struct from vector
+// CHECK-LABEL: call10
+// CHECK: [[IAddr:%.*]] = alloca <4 x i32>, align 16
+// CHECK-NEXT: [[D:%.*]] = alloca %struct.Derived, align 1
+// CHECK-NEXT: store <4 x i32> %I, ptr [[IAddr]], align 16
+// CHECK-NEXT: [[A:%.*]] = load <4 x i32>, ptr [[IAddr]], align 16
+// CHECK-NEXT: [[Gep:%.*]] = getelementptr inbounds %struct.Derived, ptr [[D]], i32 0, i32 0
+// CHECK-NEXT: [[E:%.*]] = getelementptr inbounds nuw %struct.BFields, ptr [[Gep]], i32 0, i32 1
+// CHECK-NEXT: [[Gep1:%.*]] = getelementptr inbounds %struct.Derived, ptr [[D]], i32 0, i32 0, i32 0
+// CHECK-NEXT: [[Gep2:%.*]] = getelementptr inbounds %struct.Derived, ptr [[D]], i32 0, i32 0, i32 2
+// CHECK-NEXT: [[Gep3:%.*]] = getelementptr inbounds %struct.Derived, ptr [[D]], i32 0, i32 1
+// CHECK-NEXT: [[VL:%.*]] = extractelement <4 x i32> [[A]], i64 0
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to double
+// CHECK-NEXT: store double [[C]], ptr [[Gep1]], align 8
+// CHECK-NEXT: [[VL4:%.*]] = extractelement <4 x i32> [[A]], i64 1
+// CHECK-NEXT: [[B:%.*]] = trunc i32 [[VL4]] to i24
+// CHECK-NEXT: [[BFL:%.*]] = load i24, ptr [[E]], align 1
+// CHECK-NEXT: [[BFV:%.*]] = and i24 [[B]], 32767
+// CHECK-NEXT: [[BFC:%.*]] = and i24 [[BFL]], -32768
+// CHECK-NEXT: [[BFSet:%.*]] = or i24 [[BFC]], [[BFV]]
+// CHECK-NEXT: store i24 [[BFSet]], ptr [[E]], align 1
+// CHECK-NEXT: [[VL5:%.*]] = extractelement <4 x i32> [[A]], i64 2
+// CHECK-NEXT: [[C6:%.*]] = sitofp i32 [[VL5]] to float
+// CHECK-NEXT: store float [[C6]], ptr [[Gep2]], align 4
+// CHECK-NEXT: [[VL7:%.*]] = extractelement <4 x i32> [[A]], i64 3
+// CHECK-NEXT: store i32 [[VL7]], ptr [[Gep3]], align 4
+// CHECK-NEXT: ret void
+export void call10(int4 I) {
+  Derived D = (Derived)I;
+}
+
+// truncate derived struct
+// CHECK-LABEL: call11
+// CHECK: [[B:%.*]] = alloca %struct.BFields, align 1
+// CHECK-NEXT: [[Tmp:%.*]] = alloca %struct.Derived, align 1
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 1 [[Tmp]], ptr align 1 [[D]], i32 19, i1 false)
+// CHECK-NEXT: [[Gep:%.*]] = getelementptr inbounds %struct.BFields, ptr [[B]], i32 0
+// CHECK-NEXT: [[E:%.*]] = getelementptr inbounds nuw %struct.BFields, ptr [[Gep]], i32 0, i32 1
+// CHECK-NEXT: [[Gep1:%.*]] = getelementptr inbounds %struct.BFields, ptr [[B]], i32 0, i32 0
+// CHECK-NEXT: [[Gep2:%.*]] = getelementptr inbounds %struct.BFields, ptr [[B]], i32 0, i32 2
+// CHECK-NEXT: [[Gep3:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 0
+// CHECK-NEXT: [[E4:%.*]] = getelementptr inbounds nuw %struct.BFields, ptr [[Gep3]], i32 0, i32 1
+// CHECK-NEXT: [[Gep5:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 0, i32 0
+// CHECK-NEXT: [[Gep6:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 0, i32 2
+// CHECK-NEXT: [[Gep7:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 1
+// CHECK-NEXT: [[A:%.*]] = load double, ptr [[Gep5]], align 8
+// CHECK-NEXT: store double [[A]], ptr [[Gep1]], align 8
+// CHECK-NEXT: [[BFL:%.*]] = load i24, ptr [[E4]], align 1
+// CHECK-NEXT: [[Shl:%.*]] = shl i24 [[BFL]], 9
+// CHECK-NEXT: [[Ashr:%.*]] = ashr i24 [[Shl]], 9
+// CHECK-NEXT: [[BFC:%.*]] = sext i24 [[Ashr]] to i32
+// CHECK-NEXT: [[B:%.*]] = trunc i32 [[BFC]] to i24
+// CHECK-NEXT: [[BFL8:%.*]] = load i24, ptr [[E]], align 1
+// CHECK-NEXT: [[BFV:%.*]] = and i24 [[B]], 32767
+// CHECK-NEXT: [[BFC:%.*]] = and i24 [[BFL8]], -32768
+// CHECK-NEXT: [[BFSet:%.*]] = or i24 [[BFC]], [[BFV]]
+// CHECK-NEXT: store i24 [[BFSet]], ptr [[E]], align 1
+// CHECK-NEXT: [[C:%.*]] = load float, ptr [[Gep6]], align 4
+// CHECK-NEXT: store float [[C]], ptr [[Gep2]], align 4
+// CHECK-NEXT: ret void
+export void call11(Derived D) {
+  BFields B = (BFields)D;
+}
+
+struct Empty {
+};
+
+// cast to an empty struct
+// CHECK-LABEL: call12
+// CHECK: [[I:%.*]] = alloca <4 x i32>, align 16
+// CHECK-NEXT: [[E:%.*]] = alloca %struct.Empty, align 1
+// CHECK-NEXT: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr [[I]], align 16
+// CHECK-NEXT: [[A:%.*]] = load <4 x i32>, ptr [[I]], align 16
+// CHECK-NEXT: ret void
+export void call12() {
+  int4 I = {1,2,3,4};
+  Empty E = (Empty)I;
+}
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/VectorElementwiseCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/VectorElementwiseCast.hlsl
index 253b38a7ca072..26aa41aaf4626 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/VectorElementwiseCast.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/VectorElementwiseCast.hlsl
@@ -79,3 +79,45 @@ export void call5() {
  S s = {1, 2.0};
  int A = (int)s;
 }
+
+struct BFields {
+  double D;
+  int E: 15;
+  int : 8;
+  float F;
+};
+
+struct Derived : BFields {
+  int G;
+};
+
+// vector flat cast from derived struct with bitfield
+// CHECK-LABEL: call6
+// CHECK: [[A:%.*]] = alloca <4 x i32>, align 16
+// CHECK-NEXT: [[Tmp:%.*]] = alloca %struct.Derived, align 1
+// CHECK-NEXT: [[FlatTmp:%.*]] = alloca <4 x i32>, align 16
+// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 1 [[Tmp]], ptr align 1 %D, i32 19, i1 false)
+// CHECK-NEXT: [[Gep:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 0
+// CHECK-NEXT: [[E:%.*]] = getelementptr inbounds nuw %struct.BFields, ptr [[Gep]], i32 0, i32 1
+// CHECK-NEXT: [[Gep1:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 0, i32 0
+// CHECK-NEXT: [[Gep2:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 0, i32 2
+// CHECK-NEXT: [[Gep3:%.*]] = getelementptr inbounds %struct.Derived, ptr [[Tmp]], i32 0, i32 1
+// CHECK-NEXT: [[Z:%.*]] = load <4 x i32>, ptr [[FlatTmp]], align 16
+// CHECK-NEXT: [[Y:%.*]] = load double, ptr [[Gep1]], align 8
+// CHECK-NEXT: [[C:%.*]] = fptosi double [[Y]] to i32
+// CHECK-NEXT: [[X:%.*]] = insertelement <4 x i32> [[Z]], i32 [[C]], i64 0
+// CHECK-NEXT: [[BFL:%.*]] = load i24, ptr [[E]], align 1
+// CHECK-NEXT: [[BFShl:%.*]] = shl i24 [[BFL]], 9
+// CHECK-NEXT: [[BFAshr:%.*]] = ashr i24 [[BFShl]], 9
+// CHECK-NEXT: [[BFC:%.*]] = sext i24 [[BFAshr]] to i32
+// CHECK-NEXT: [[W:%.*]] = insertelement <4 x i32> [[X]], i32 [[BFC]], i64 1
+// CHECK-NEXT: [[V:%.*]] = load float, ptr [[Gep2]], align 4
+// CHECK-NEXT: [[C4:%.*]] = fptosi float [[V]] to i32
+// CHECK-NEXT: [[U:%.*]] = insertelement <4 x i32> [[W]], i32 [[C4]], i64 2
+// CHECK-NEXT: [[T:%.*]] = load i32, ptr [[Gep3]], align 4
+// CHECK-NEXT: [[S:%.*]] = insertelement <4 x i32> [[U]], i32 [[T]], i64 3
+// CHECK-NEXT: store <4 x i32> [[S]], ptr [[A]], align 16
+// CHECK-NEXT: ret void
+export void call6(Derived D) {
+  int4 A = (int4)D;
+}
diff --git a/clang/test/SemaHLSL/Language/AggregateSplatCast-errors.hlsl b/clang/test/SemaHLSL/Language/AggregateSplatCast-errors.hlsl
index 2320e136b2536..fbb47bd2e7d39 100644
--- a/clang/test/SemaHLSL/Language/AggregateSplatCast-errors.hlsl
+++ b/clang/test/SemaHLSL/Language/AggregateSplatCast-errors.hlsl
@@ -13,12 +13,6 @@ struct R {
   };
 };
 
-// casting types which contain bitfields is not yet supported.
-export void cantCast() {
-  S s = (S)1;
-  // expected-error at -1 {{no matching conversion for C-style cast from 'int' to 'S'}}
-}
-
 // Can't cast a union
 export void cantCast2() {
   R r = (R)1;
diff --git a/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl b/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl
index 30591507b3260..d9f50e9b0307f 100644
--- a/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl
+++ b/clang/test/SemaHLSL/Language/ElementwiseCast-errors.hlsl
@@ -7,27 +7,6 @@ export void cantCast() {
   // expected-error at -1 {{C-style cast from 'int[3]' to 'int[4]' is not allowed}}
 }
 
-struct S {
-// expected-note at -1 {{candidate constructor (the implicit copy constructor) not viable: no known conversion from 'int2' (aka 'vector<int, 2>') to 'const S' for 1st argument}}
-// expected-note at -2 {{candidate constructor (the implicit move constructor) not viable: no known conversion from 'int2' (aka 'vector<int, 2>') to 'S' for 1st argument}}
-// expected-note at -3 {{candidate constructor (the implicit default constructor) not viable: requires 0 arguments, but 1 was provided}}
-  int A : 8;
-  int B;
-};
-
-// casting types which contain bitfields is not yet supported.
-export void cantCast2() {
-  S s = {1,2};
-  int2 C = (int2)s;
-  // expected-error at -1 {{cannot convert 'S' to 'int2' (aka 'vector<int, 2>') without a conversion operator}}
-}
-
-export void cantCast3() {
-  int2 C = {1,2};
-  S s = (S)C;
-  // expected-error at -1 {{no matching conversion for C-style cast from 'int2' (aka 'vector<int, 2>') to 'S'}}
-}
-
 struct R {
 // expected-note at -1 {{candidate constructor (the implicit copy constructor) not viable: no known conversion from 'int2' (aka 'vector<int, 2>') to 'const R' for 1st argument}}
 // expected-note at -2 {{candidate constructor (the implicit move constructor) not viable: no known conversion from 'int2' (aka 'vector<int, 2>') to 'R' for 1st argument}}

From 6374c41ab40ff0ab6883d8c252621202021fb214 Mon Sep 17 00:00:00 2001
From: Sarah Spall <sarahspall at microsoft.com>
Date: Thu, 2 Oct 2025 11:46:45 -0700
Subject: [PATCH 2/2] Respond to PR comments

---
 clang/lib/CodeGen/CGExpr.cpp                  | 10 ++-
 clang/lib/CodeGen/CGExprScalar.cpp            |  3 +-
 .../BasicFeatures/AggregateSplatCast.hlsl     | 20 +++---
 .../BasicFeatures/StructElementwiseCast.hlsl  | 67 ++++++++++++++++++-
 4 files changed, 81 insertions(+), 19 deletions(-)

diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 04e2b64d2bd7c..9f30287b68c79 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -6803,7 +6803,7 @@ void CodeGenFunction::FlattenAccessAndTypeLValue(
       for (int64_t I = Size - 1; I > -1; I--) {
         llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
         IdxListCopy.push_back(llvm::ConstantInt::get(IdxTy, I));
-        WorkList.push_back({LVal, CAT->getElementType(), IdxListCopy});
+        WorkList.emplace_back(LVal, CAT->getElementType(), IdxListCopy);
       }
     } else if (const auto *RT = dyn_cast<RecordType>(T)) {
       const RecordDecl *Record = RT->getOriginalDecl()->getDefinitionOrSelf();
@@ -6826,8 +6826,7 @@ void CodeGenFunction::FlattenAccessAndTypeLValue(
           llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
           IdxListCopy.push_back(llvm::ConstantInt::get(
               IdxTy, 0)); // base struct should be at index zero
-          ReverseList.insert(ReverseList.end(),
-                             {LVal, Base->getType(), IdxListCopy});
+          ReverseList.emplace_back(LVal, Base->getType(), IdxListCopy);
         }
       }
 
@@ -6848,13 +6847,12 @@ void CodeGenFunction::FlattenAccessAndTypeLValue(
             RLValue = MakeAddrLValue(GEP, T);
           }
           LValue FieldLVal = EmitLValueForField(RLValue, FD, true);
-          ReverseList.insert(ReverseList.end(), {FieldLVal, FD->getType(), {}});
+          ReverseList.push_back({FieldLVal, FD->getType(), {}});
         } else {
           llvm::SmallVector<llvm::Value *, 4> IdxListCopy = IdxList;
           IdxListCopy.push_back(
               llvm::ConstantInt::get(IdxTy, Layout.getLLVMFieldNo(FD)));
-          ReverseList.insert(ReverseList.end(),
-                             {LVal, FD->getType(), IdxListCopy});
+          ReverseList.emplace_back(LVal, FD->getType(), IdxListCopy);
         }
       }
 
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 00ebf41b315bb..83a2129fb8b65 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2400,7 +2400,8 @@ static Value *EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue SrcVal,
   // if its a vector create a temp alloca to store into and return that
   if (auto *VecTy = DestTy->getAs<VectorType>()) {
     assert(LoadList.size() >= VecTy->getNumElements() &&
-           "Flattened type on RHS must have more elements than vector on LHS.");
+           "Flattened type on RHS must have the same number or more elements "
+           "than vector on LHS.");
     llvm::Value *V =
         CGF.Builder.CreateLoad(CGF.CreateIRTemp(DestTy, "flatcast.tmp"));
     // write to V.
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
index 63b2b50dd9885..9524f024e8d46 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/AggregateSplatCast.hlsl
@@ -54,18 +54,16 @@ struct S {
 
 // struct splats
 // CHECK-LABEL: define void {{.*}}call3
-// CHECK: [[A:%.*]] = alloca <1 x i32>, align 4
+// CHECK: [[AA:%.*]] = alloca i32, align 4
 // CHECK: [[s:%.*]] = alloca %struct.S, align 1
-// CHECK-NEXT: store <1 x i32> splat (i32 1), ptr [[A]], align 4
-// CHECK-NEXT: [[L:%.*]] = load <1 x i32>, ptr [[A]], align 4
-// CHECK-NEXT: [[VL:%.*]] = extractelement <1 x i32> [[L]], i32 0
+// CHECK-NEXT: store i32 %A, ptr [[AA]], align 4
+// CHECK-NEXT: [[L:%.*]] = load i32, ptr [[AA]], align 4
 // CHECK-NEXT: [[G1:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 0
 // CHECK-NEXT: [[G2:%.*]] = getelementptr inbounds %struct.S, ptr [[s]], i32 0, i32 1
-// CHECK-NEXT: store i32 [[VL]], ptr [[G1]], align 4
-// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[VL]] to float
+// CHECK-NEXT: store i32 [[L]], ptr [[G1]], align 4
+// CHECK-NEXT: [[C:%.*]] = sitofp i32 [[L]] to float
 // CHECK-NEXT: store float [[C]], ptr [[G2]], align 4
-export void call3() {
-  int1 A = {1};
+export void call3(int A) {
   S s = (S)A;
 }
 
@@ -87,7 +85,7 @@ export void call5() {
 }
 
 struct BFields {
-  double D;
+  double DF;
   int E: 15;
   int : 8;
   float F;
@@ -110,9 +108,9 @@ struct Derived : BFields {
 // CHECK-NEXT: [[Gep3:%.*]] = getelementptr inbounds %struct.Derived, ptr [[D]], i32 0, i32 1
 // CHECK-NEXT: [[C:%.*]] = sitofp i32 [[B]] to double
 // CHECK-NEXT: store double [[C]], ptr [[Gep1]], align 8
-// CHECK-NEXT: [[D:%.*]] = trunc i32 [[B]] to i24
+// CHECK-NEXT: [[H:%.*]] = trunc i32 [[B]] to i24
 // CHECK-NEXT: [[BFL:%.*]] = load i24, ptr [[E]], align 1
-// CHECK-NEXT: [[BFV:%.*]] = and i24 [[D]], 32767
+// CHECK-NEXT: [[BFV:%.*]] = and i24 [[H]], 32767
 // CHECK-NEXT: [[BFC:%.*]] = and i24 [[BFL]], -32768
 // CHECK-NEXT: [[BFS:%.*]] = or i24 [[BFC]], [[BFV]]
 // CHECK-NEXT: store i24 [[BFS]], ptr [[E]], align 1
diff --git a/clang/test/CodeGenHLSL/BasicFeatures/StructElementwiseCast.hlsl b/clang/test/CodeGenHLSL/BasicFeatures/StructElementwiseCast.hlsl
index 39928e87c4df9..09c73c266196b 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/StructElementwiseCast.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/StructElementwiseCast.hlsl
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
+// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
 
 struct S {
   int X;
@@ -245,3 +245,68 @@ export void call12() {
   int4 I = {1,2,3,4};
   Empty E = (Empty)I;
 }
+
+struct MoreBFields {
+  int A;
+  uint64_t B: 60;
+  float C;
+  uint16_t D: 10;
+  uint16_t E: 6;
+  int : 32;
+  double F;
+  int : 8;
+  uint G;
+};
+
+// more complicated bitfield case
+// CHECK-LABEL: call13
+// CHECK: [[AA:%.*]] = alloca i32, align 4
+// CHECK-NEXT: [[MBF:%.*]] = alloca %struct.MoreBFields, align 1
+// CHECK-NEXT: store i32 %A, ptr [[AA]], align 4
+// CHECK-NEXT: [[Z:%.*]] = load i32, ptr [[AA]], align 4
+// get the gep for the struct.
+// CHECK-NEXT: [[Gep:%.*]] = getelementptr inbounds %struct.MoreBFields, ptr [[MBF]], i32 0
+// CHECK-NEXT: [[FieldB:%.*]] = getelementptr inbounds nuw %struct.MoreBFields, ptr [[Gep]], i32 0, i32 1
+// D and E share the same field index
+// CHECK-NEXT: [[FieldD:%.*]] = getelementptr inbounds nuw %struct.MoreBFields, ptr [[Gep]], i32 0, i32 3
+// CHECK-NEXT: [[FieldE:%.*]] = getelementptr inbounds nuw %struct.MoreBFields, ptr [[Gep]], i32 0, i32 3
+// CHECK-NEXT: [[FieldA:%.*]] = getelementptr inbounds %struct.MoreBFields, ptr [[MBF]], i32 0, i32 0
+// CHECK-NEXT: [[FieldC:%.*]] = getelementptr inbounds %struct.MoreBFields, ptr [[MBF]], i32 0, i32 2
+// CHECK-NEXT: [[FieldF:%.*]] = getelementptr inbounds %struct.MoreBFields, ptr [[MBF]], i32 0, i32 5
+// CHECK-NEXT: [[FieldG:%.*]] = getelementptr inbounds %struct.MoreBFields, ptr [[MBF]], i32 0, i32 7
+// store int A into field A
+// CHECK-NEXT: store i32 [[Z]], ptr [[FieldA]], align 4
+// store int A into bitfield B, doing the necessary conversions
+// CHECK-NEXT: [[Conv:%.*]] = sext i32 [[Z]] to i64
+// CHECK-NEXT: [[BFL:%.*]] = load i64, ptr [[FieldB]], align 1
+// CHECK-NEXT: [[BFV:%.*]] = and i64 [[Conv]], 1152921504606846975
+// CHECK-NEXT: [[BFC:%.*]] = and i64 [[BFL]], -1152921504606846976
+// CHECK-NEXT: [[BFS:%.*]] = or i64 [[BFC]], [[BFV]]
+// CHECK-NEXT: store i64 [[BFS]], ptr [[FieldB]], align 1
+// store int A into field C
+// CHECK-NEXT: [[Conv5:%.*]] = sitofp i32 [[Z]] to float
+// CHECK-NEXT: store float [[Conv5]], ptr [[FieldC]], align 4
+// store int A into bitfield D
+// CHECK-NEXT: [[Conv6:%.*]] = trunc i32 [[Z]] to i16
+// CHECK-NEXT: [[FDL:%.*]] = load i16, ptr [[FieldD]], align 1
+// CHECK-NEXT: [[FDV:%.*]] = and i16 [[Conv6]], 1023
+// CHECK-NEXT: [[FDC:%.*]] = and i16 [[FDL]], -1024
+// CHECK-NEXT: [[FDS:%.*]] = or i16 [[FDC]], [[FDV]]
+// CHECK-NEXT: store i16 [[FDS]], ptr [[FieldD]], align 1
+// store int A into bitfield E
+// CHECK-NEXT: [[Conv11:%.*]] = trunc i32 [[Z]] to i16
+// CHECK-NEXT: [[FEL:%.*]] = load i16, ptr [[FieldE]], align 1
+// CHECK-NEXT: [[FEV:%.*]] = and i16 [[Conv11]], 63
+// CHECK-NEXT: [[FESHL:%.*]] = shl i16 [[FEV]], 10
+// CHECK-NEXT: [[FEC:%.*]] = and i16 [[FEL]], 1023
+// CHECK-NEXT: [[FES:%.*]] = or i16 [[FEC]], [[FESHL]]
+// CHECK-NEXT: store i16 [[FES]], ptr [[FieldE]], align 1
+// store int A into field F
+// CHECK-NEXT: [[Conv16:%.*]] = sitofp i32 [[Z]] to double
+// CHECK-NEXT: store double [[Conv16]], ptr [[FieldF]], align 8
+// store int A into field G
+// CHECK-NEXT: store i32 [[Z]], ptr [[FieldG]], align 4
+// CHECK-NEXT: ret void
+export void call13(int A) {
+  MoreBFields MBF = (MoreBFields)A;
+}


