[clang] [HLSL] Implement HLSL initialization list support (PR #123141)

Farzon Lotfi via cfe-commits cfe-commits at lists.llvm.org
Tue Feb 11 14:58:09 PST 2025


================
@@ -3010,3 +3010,183 @@ void SemaHLSL::processExplicitBindingsOnDecl(VarDecl *VD) {
     }
   }
 }
+
+/// Appends \p E to \p List, converting it first to the destination type that
+/// matches the slot it is about to occupy.
+///
+/// The slot index is the current size of \p List. If that index is past the
+/// end of \p DestTypes there is no type to convert to, so \p E is appended
+/// unchanged.
+///
+/// \returns false when the copy-initialization to the destination type is
+/// invalid (nothing is appended in that case), true otherwise.
+static bool CastInitializer(Sema &S, ASTContext &Ctx, Expr *E,
+                            llvm::SmallVectorImpl<Expr *> &List,
+                            llvm::SmallVectorImpl<QualType> &DestTypes) {
+  const auto Slot = List.size();
+  if (Slot < DestTypes.size()) {
+    // A destination type exists for this slot: convert the expression to it.
+    InitializedEntity Target =
+        InitializedEntity::InitializeParameter(Ctx, DestTypes[Slot], false);
+    ExprResult Converted =
+        S.PerformCopyInitialization(Target, E->getBeginLoc(), E);
+    if (Converted.isInvalid())
+      return false;
+    E = Converted.get();
+  }
+  List.push_back(E);
+  return true;
+}
+
+/// Flattens the initializer \p E into \p List.
+///
+/// - An InitListExpr is traversed recursively through its sub-initializers.
+/// - A scalar expression is appended through CastInitializer.
+/// - A vector contributes one subscript expression per element, each appended
+///   through CastInitializer.
+/// - A constant array contributes one subscript expression per element, each
+///   flattened recursively.
+/// - A record contributes one field reference per field, base class fields
+///   first, each flattened recursively.
+///
+/// \p ExcessInits is set to true when \p List already holds at least as many
+/// entries as \p DestTypes on entry, or when a vector's elements would push
+/// \p List past the size of \p DestTypes. It is never reset to false here.
+///
+/// If building a subscript or field reference fails, the function returns
+/// early and leaves \p List partially populated.
+static void BuildInitializerList(Sema &S, ASTContext &Ctx, Expr *E,
+                                 llvm::SmallVectorImpl<Expr *> &List,
+                                 llvm::SmallVectorImpl<QualType> &DestTypes,
+                                 bool &ExcessInits) {
+  if (List.size() >= DestTypes.size())
+    ExcessInits = true;
+
+  // Nested braces: flatten every sub-initializer in order.
+  if (auto *Braced = dyn_cast<InitListExpr>(E)) {
+    for (auto *Child : Braced->inits())
+      BuildInitializerList(S, Ctx, Child, List, DestTypes, ExcessInits);
+    return;
+  }
+
+  // Scalars are leaves: enqueue them directly.
+  QualType SrcTy = E->getType();
+  if (SrcTy->isScalarType()) {
+    (void)CastInitializer(S, Ctx, E, List, DestTypes);
+    return;
+  }
+
+  // Builds the expression `E[I]` with a size_t-typed integer literal index.
+  auto Subscript = [&](uint64_t I) -> ExprResult {
+    QualType IdxTy = Ctx.getSizeType();
+    uint64_t IdxBits = Ctx.getTypeSize(IdxTy);
+    auto *Idx = IntegerLiteral::Create(Ctx, llvm::APInt(IdxBits, I), IdxTy,
+                                       SourceLocation());
+    return S.CreateBuiltinArraySubscriptExpr(E, E->getBeginLoc(), Idx,
+                                             E->getEndLoc());
+  };
+
+  // Vectors: each element is a scalar, so cast and enqueue it directly.
+  if (auto *VecTy = SrcTy->getAs<VectorType>()) {
+    const uint64_t NumElts = VecTy->getNumElements();
+    if (List.size() + NumElts > DestTypes.size())
+      ExcessInits = true;
+
+    for (uint64_t I = 0; I != NumElts; ++I) {
+      ExprResult Elt = Subscript(I);
+      if (Elt.isInvalid())
+        return;
+      CastInitializer(S, Ctx, Elt.get(), List, DestTypes);
+    }
+    return;
+  }
+
+  // Constant arrays: elements may themselves be aggregates, so recurse.
+  if (auto *ArrTy = dyn_cast<ConstantArrayType>(SrcTy.getTypePtr())) {
+    const uint64_t NumElts = ArrTy->getZExtSize();
+    for (uint64_t I = 0; I != NumElts; ++I) {
+      ExprResult Elt = Subscript(I);
+      if (Elt.isInvalid())
+        return;
+      BuildInitializerList(S, Ctx, Elt.get(), List, DestTypes, ExcessInits);
+    }
+    return;
+  }
+
+  // Records: gather the inheritance chain from most-derived to root, then
+  // visit it in reverse so that base class fields come first.
+  if (auto *RecTy = SrcTy->getAs<RecordType>()) {
+    llvm::SmallVector<const RecordType *> Chain;
+    Chain.push_back(RecTy);
+    for (CXXRecordDecl *D = RecTy->getAsCXXRecordDecl(); D->getNumBases();) {
+      assert(D->getNumBases() == 1 &&
+             "HLSL doesn't support multiple inheritance");
+      const RecordType *BaseTy =
+          D->bases_begin()->getType()->getAs<RecordType>();
+      Chain.push_back(BaseTy);
+      D = BaseTy->getAsCXXRecordDecl();
+    }
+
+    for (auto Level = Chain.rbegin(); Level != Chain.rend(); ++Level) {
+      for (auto *Field : (*Level)->getDecl()->fields()) {
+        DeclAccessPair Access =
+            DeclAccessPair::make(Field, Field->getAccess());
+        DeclarationNameInfo FieldName(Field->getDeclName(), E->getBeginLoc());
+        ExprResult Member =
+            S.BuildFieldReferenceExpr(E, false, E->getBeginLoc(),
+                                      CXXScopeSpec(), Field, Access, FieldName);
+        if (Member.isInvalid())
+          return;
+        BuildInitializerList(S, Ctx, Member.get(), List, DestTypes,
+                             ExcessInits);
+      }
+    }
+  }
+}
+
+/// Rebuilds a structured initializer for \p Ty from a flat sequence of
+/// expressions.
+///
+/// For a scalar type, the expression at \p It is returned and \p It is
+/// advanced by one. For a vector or constant array, one sub-initializer is
+/// generated per element. For a record, one sub-initializer is generated per
+/// field, visiting base class fields before derived class fields. The
+/// sub-initializers are wrapped in a new InitListExpr whose type is the
+/// desugared \p Ty.
+///
+/// \param It Iterator into the flattened expression list; advanced past every
+///           expression consumed. The function does not bounds-check it, so
+///           the sequence must hold enough expressions for \p Ty.
+/// \returns the scalar expression, or the newly created InitListExpr.
+static Expr *GenerateInitLists(ASTContext &Ctx, QualType Ty,
+                               llvm::SmallVectorImpl<Expr *>::iterator &It) {
+  if (Ty->isScalarType()) {
+    return *(It++);
+  }
+  llvm::SmallVector<Expr *> Inits;
+  Ty = Ty.getDesugaredType(Ctx);
+  // Checked after desugaring so a matrix hidden behind a typedef is caught.
+  assert(!isa<MatrixType>(Ty) && "Matrix types not yet supported in HLSL");
+  if (Ty->isVectorType() || Ty->isConstantArrayType()) {
+    QualType ElTy;
+    uint64_t Size = 0;
+    if (auto *ATy = Ty->getAs<VectorType>()) {
+      ElTy = ATy->getElementType();
+      Size = ATy->getNumElements();
+    } else {
+      auto *VTy = cast<ConstantArrayType>(Ty.getTypePtr());
+      ElTy = VTy->getElementType();
+      Size = VTy->getZExtSize();
+    }
+    for (uint64_t I = 0; I < Size; ++I)
+      Inits.push_back(GenerateInitLists(Ctx, ElTy, It));
+  }
+  if (auto *RTy = Ty->getAs<RecordType>()) {
+    // Collect the inheritance chain from most-derived to root, then pop from
+    // the back so that base class fields are initialized first.
+    llvm::SmallVector<const RecordType *> RecordTypes;
+    RecordTypes.push_back(RTy);
+    while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
+      CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
+      assert(D->getNumBases() == 1 &&
+             "HLSL doesn't support multiple inheritance");
+      RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>());
+    }
+    while (!RecordTypes.empty()) {
+      const RecordType *RT = RecordTypes.back();
+      RecordTypes.pop_back();
+      for (auto *FD : RT->getDecl()->fields()) {
+        Inits.push_back(GenerateInitLists(Ctx, FD->getType(), It));
+      }
+    }
+  }
+  // A record without fields or a zero-length array yields no
+  // sub-initializers; calling front()/back() on the empty vector would be
+  // undefined behaviour, so use invalid (default) brace locations instead.
+  SourceLocation LBraceLoc;
+  SourceLocation RBraceLoc;
+  if (!Inits.empty()) {
+    LBraceLoc = Inits.front()->getBeginLoc();
+    RBraceLoc = Inits.back()->getEndLoc();
+  }
+  auto *NewInit = new (Ctx) InitListExpr(Ctx, LBraceLoc, Inits, RBraceLoc);
+  NewInit->setType(Ty);
+  return NewInit;
+}
+
+bool SemaHLSL::TransformInitList(const InitializedEntity &Entity,
+                                 const InitializationKind &Kind,
+                                 InitListExpr *Init) {
+  // If the initializer is a scalar, just return it.
+  if (Init->getType()->isScalarType())
+    return true;
+  ASTContext &Ctx = SemaRef.getASTContext();
+  llvm::SmallVector<QualType, 16> DestTypes;
+  // An initializer list might be attempting to initialize a reference or
+  // rvalue-reference. When checking the initializer we should look through the
+  // reference.
+  QualType InitTy = Entity.getType().getNonReferenceType();
+  BuildFlattenedTypeList(InitTy, DestTypes);
+
+  llvm::SmallVector<Expr *, 16> ArgExprs;
+  bool ExcessInits = false;
+  for (Expr *Arg : Init->inits())
+    BuildInitializerList(SemaRef, Ctx, Arg, ArgExprs, DestTypes, ExcessInits);
+
+  if (DestTypes.size() != ArgExprs.size() || ExcessInits) {
+    int TooManyOrFew = ExcessInits ? 1 : 0;
+    SemaRef.Diag(Init->getBeginLoc(), diag::err_hlsl_incorrect_num_initializers)
----------------
farzonl wrote:

The `note_ovl_candidate_bad_list_argument` note does almost everything you want except emit an error, and you are returning false here anyway. Do we need a special diagnostic for this? If so, does it need to be HLSL-specific? It seems like it could be useful in other cases; maybe replace `err_excess_initializers` with one that can say `too few` or `too many`?

https://github.com/llvm/llvm-project/pull/123141


More information about the cfe-commits mailing list