[clang] [HLSL] add support for HLSLAggregateSplatCast and HLSLElementwiseCast to constant expression evaluator (PR #164700)

Joshua Batista via cfe-commits cfe-commits at lists.llvm.org
Wed Oct 29 13:08:14 PDT 2025


================
@@ -3829,6 +3829,330 @@ static bool CheckArraySize(EvalInfo &Info, const ConstantArrayType *CAT,
       /*Diag=*/true);
 }
 
+// Convert the scalar Original of type SourceTy to DestTy, storing the
+// converted value in Result. Supported source and destination types are
+// bool, integer and real floating-point types. Any other combination, or a
+// failed conversion, returns false; Result must not be used in that case.
+static bool handleScalarCast(EvalInfo &Info, const FPOptions FPO, const Expr *E,
+                             QualType SourceTy, QualType DestTy,
+                             APValue const &Original, APValue &Result) {
+  // boolean must be checked before integer
+  // since IsIntegerType() is true for bool
+  if (SourceTy->isBooleanType()) {
+    if (DestTy->isBooleanType()) {
+      Result = Original;
+      return true;
+    }
+    // Reject unsupported destinations (e.g. complex or pointer types) instead
+    // of reporting success without having written Result.
+    if (!DestTy->isIntegerType() && !DestTy->isRealFloatingType())
+      return false;
+
+    bool BoolResult;
+    if (!HandleConversionToBool(Original, BoolResult))
+      return false;
+    uint64_t IntResult = BoolResult;
+    if (DestTy->isIntegerType()) {
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    // Real floating destination: materialize the 0/1 value as a 64-bit
+    // unsigned integer and reuse the integer-to-float conversion.
+    QualType IntType = Info.Ctx.getIntTypeForBitwidth(64, /*Signed=*/false);
+    APSInt IntValue = Info.Ctx.MakeIntValue(IntResult, IntType);
+    Result = APValue(APFloat(0.0));
+    return HandleIntToFloatCast(Info, E, FPO, IntType, IntValue, DestTy,
+                                Result.getFloat());
+  }
+  if (SourceTy->isIntegerType()) {
+    if (DestTy->isRealFloatingType()) {
+      Result = APValue(APFloat(0.0));
+      return HandleIntToFloatCast(Info, E, FPO, SourceTy, Original.getInt(),
+                                  DestTy, Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(
+          HandleIntToIntCast(Info, E, DestTy, SourceTy, Original.getInt()));
+      return true;
+    }
+  } else if (SourceTy->isRealFloatingType()) {
+    if (DestTy->isRealFloatingType()) {
+      Result = Original;
+      return HandleFloatToFloatCast(Info, E, SourceTy, DestTy,
+                                    Result.getFloat());
+    }
+    if (DestTy->isBooleanType()) {
+      bool BoolResult;
+      if (!HandleConversionToBool(Original, BoolResult))
+        return false;
+      uint64_t IntResult = BoolResult;
+      Result = APValue(Info.Ctx.MakeIntValue(IntResult, DestTy));
+      return true;
+    }
+    if (DestTy->isIntegerType()) {
+      Result = APValue(APSInt());
+      return HandleFloatToIntCast(Info, E, SourceTy, Original.getFloat(),
+                                  DestTy, Result.getInt());
+    }
+  }
+
+  // Unsupported source/destination combination.
+  return false;
+}
+
+// Build a value of type ResultType (scalar, vector, constant array or record)
+// from the flat list of scalar Elements, whose types are given by the
+// parallel list ElTypes. Elements are consumed in order: a record's base
+// class first, then its named fields; arrays and vectors element by element.
+// Each element is converted to the destination scalar type with
+// handleScalarCast. Values stored into bit-fields are truncated to the
+// bit-field's width, which is why this cannot be a plain per-element copy.
+//
+// Construction stops once Elements is exhausted; destination subobjects not
+// reached by then are left uninitialized. Returns false if a conversion
+// fails, if a vector needs more elements than remain, or if the destination
+// contains a type this routine does not handle.
+static bool constructAggregate(EvalInfo &Info, const FPOptions FPO,
+                               const Expr *E, APValue &Result,
+                               QualType ResultType,
+                               SmallVectorImpl<APValue> &Elements,
+                               SmallVectorImpl<QualType> &ElTypes) {
+  // Work items are (destination slot, its type, bit-field width or 0). Items
+  // are popped from the back, so children are pushed in reverse order.
+  SmallVector<std::tuple<APValue *, QualType, unsigned>> WorkList = {
+      {&Result, ResultType, 0}};
+
+  unsigned ElI = 0;
+  while (!WorkList.empty() && ElI < Elements.size()) {
+    auto [Res, Type, BitWidth] = WorkList.pop_back_val();
+
+    if (Type->isRealFloatingType() || Type->isIntegerType()) {
+      if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], Type, Elements[ElI],
+                            *Res))
+        return false;
+      // Only integer fields can be bit-fields (BitWidth is 0 otherwise).
+      // Truncate to the field width, then extend back (sign- or zero-extend
+      // per the APSInt's signedness) so the stored width is unchanged.
+      if (BitWidth > 0) {
+        if (!Res->isInt())
+          return false;
+        APSInt &Int = Res->getInt();
+        unsigned OldBitWidth = Int.getBitWidth();
+        if (BitWidth < OldBitWidth)
+          Int = Int.trunc(BitWidth).extend(OldBitWidth);
+      }
+      ++ElI;
+      continue;
+    }
+    if (Type->isVectorType()) {
+      const auto *VT = Type->castAs<VectorType>();
+      QualType ElTy = VT->getElementType();
+      unsigned NumEl = VT->getNumElements();
+      // A vector consumes NumEl consecutive elements; fail rather than read
+      // past the end of Elements/ElTypes when too few remain.
+      if (Elements.size() - ElI < NumEl)
+        return false;
+      SmallVector<APValue> Vals(NumEl);
+      for (unsigned I = 0; I < NumEl; ++I, ++ElI) {
+        if (!handleScalarCast(Info, FPO, E, ElTypes[ElI], ElTy, Elements[ElI],
+                              Vals[I]))
+          return false;
+      }
+      *Res = APValue(Vals.data(), NumEl);
+      continue;
+    }
+    if (Type->isConstantArrayType()) {
+      const ConstantArrayType *CAT = Info.Ctx.getAsConstantArrayType(Type);
+      QualType ElTy = CAT->getElementType();
+      uint64_t Size = CAT->getZExtSize();
+      *Res = APValue(APValue::UninitArray(), Size, Size);
+      // Push in reverse so that element 0 is processed first.
+      for (uint64_t I = Size; I > 0; --I)
+        WorkList.emplace_back(&Res->getArrayInitializedElt(I - 1), ElTy, 0u);
+      continue;
+    }
+    if (Type->isRecordType()) {
+      const RecordDecl *RD = Type->getAsRecordDecl();
+      const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD);
+      unsigned NumBases = CXXRD ? CXXRD->getNumBases() : 0;
+
+      *Res = APValue(APValue::UninitStruct(), NumBases,
+                     std::distance(RD->field_begin(), RD->field_end()));
+
+      // Collect children in visit order (base first, then fields) and push
+      // them reversed so they are popped in that order.
+      SmallVector<std::tuple<APValue *, QualType, unsigned>> ReverseList;
+      if (NumBases > 0) {
+        // Only a single base class is handled here.
+        assert(NumBases == 1);
+        const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
+        ReverseList.emplace_back(&Res->getStructBase(0), BS.getType(), 0u);
+      }
+
+      for (FieldDecl *FD : RD->fields()) {
+        // Unnamed bit-fields are skipped and consume no element.
+        if (FD->isUnnamedBitField())
+          continue;
+        unsigned FieldBitWidth = FD->isBitField() ? FD->getBitWidthValue() : 0;
+        ReverseList.emplace_back(&Res->getStructField(FD->getFieldIndex()),
+                                 FD->getType(), FieldBitWidth);
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+      continue;
+    }
+    // Unsupported destination type.
+    return false;
+  }
+  return true;
+}
+
+// Convert Elements[i] from SrcTypes[i] to DestTypes[i] for every i, writing
+// each converted scalar to Results[i] via handleScalarCast. Results must
+// already hold at least Elements.size() entries. Stops and returns false at
+// the first conversion that fails; returns true otherwise.
+static bool handleElementwiseCast(EvalInfo &Info, const Expr *E,
+                                  const FPOptions FPO,
+                                  SmallVectorImpl<APValue> &Elements,
+                                  SmallVectorImpl<QualType> &SrcTypes,
+                                  SmallVectorImpl<QualType> &DestTypes,
+                                  SmallVectorImpl<APValue> &Results) {
+  const size_t Count = Elements.size();
+  assert((Count == SrcTypes.size()) && (Count == DestTypes.size()));
+
+  for (size_t Idx = 0; Idx != Count; ++Idx) {
+    // Convert from a copy, so the source value stays intact even if Results
+    // shares storage with Elements.
+    const APValue Source = Elements[Idx];
+    if (!handleScalarCast(Info, FPO, E, SrcTypes[Idx], DestTypes[Idx], Source,
+                          Results[Idx]))
+      return false;
+  }
+  return true;
+}
+
+// Count how many scalar elements an object of type BaseTy flattens to.
+// - Each bool, integer or real floating-point scalar counts once.
+// - A vector counts its lanes.
+// - A constant array counts its element type times its length.
+// - A record counts its base class (at most one) plus its named fields.
+// Types outside these categories contribute nothing.
+static unsigned elementwiseSize(EvalInfo &Info, QualType BaseTy) {
+  // Each work item pairs a type with how many times it occurs. Carrying that
+  // multiplicity lets a constant array cost a single work item instead of
+  // one item per array element.
+  SmallVector<std::pair<QualType, uint64_t>> WorkList = {{BaseTy, 1}};
+
+  unsigned Size = 0;
+  while (!WorkList.empty()) {
+    auto [Type, Count] = WorkList.pop_back_val();
+    if (Type->isRealFloatingType() || Type->isIntegerType() ||
+        Type->isBooleanType()) {
+      Size += Count;
+      continue;
+    }
+    if (Type->isVectorType()) {
+      Size += Count * Type->castAs<VectorType>()->getNumElements();
+      continue;
+    }
+    if (Type->isConstantArrayType()) {
+      const ConstantArrayType *CAT = Info.Ctx.getAsConstantArrayType(Type);
+      WorkList.emplace_back(CAT->getElementType(),
+                            Count * CAT->getZExtSize());
+      continue;
+    }
+    if (Type->isRecordType()) {
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      // Visit the base class (only a single base is handled).
+      if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        if (CXXRD->getNumBases() > 0) {
+          assert(CXXRD->getNumBases() == 1);
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
+          WorkList.emplace_back(BS.getType(), Count);
+        }
+      }
+
+      // Visit the fields; unnamed bit-fields are skipped.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        WorkList.emplace_back(FD->getType(), Count);
+      }
+      continue;
+    }
+  }
+  return Size;
+}
+
+static bool flattenAPValue(const ASTContext &Ctx, APValue Value,
+                           QualType BaseTy, SmallVectorImpl<APValue> &Elements,
+                           SmallVectorImpl<QualType> &Types, unsigned Size) {
+
+  SmallVector<std::pair<APValue, QualType>> WorkList = {{Value, BaseTy}};
+  unsigned Populated = 0;
+  while (!WorkList.empty() && Populated < Size) {
+    auto [Work, Type] = WorkList.pop_back_val();
+
+    if (Work.isFloat() || Work.isInt()) { // todo what does this do with bool
+      Elements.push_back(Work);
+      Types.push_back(Type);
+      Populated++;
+      continue;
+    }
+    if (Work.isVector()) {
+      assert(Type->isVectorType() && "Type mismatch.");
+      QualType ElTy = Type->castAs<VectorType>()->getElementType();
+      for (unsigned I = 0; I < Work.getVectorLength() && Populated < Size;
+           I++) {
+        Elements.push_back(Work.getVectorElt(I));
+        Types.push_back(ElTy);
+        Populated++;
+      }
+      continue;
+    }
+    if (Work.isArray()) {
+      assert(Type->isConstantArrayType() && "Type mismatch.");
+      QualType ElTy =
+          cast<ConstantArrayType>(Ctx.getAsArrayType(Type))->getElementType();
+      for (int64_t I = Work.getArraySize() - 1; I > -1; --I) {
+        WorkList.emplace_back(Work.getArrayInitializedElt(I), ElTy);
+      }
+      continue;
+    }
+
+    if (Work.isStruct()) {
+      assert(Type->isRecordType() && "Type mismatch.");
+
+      const RecordDecl *RD = Type->getAsRecordDecl();
+
+      SmallVector<std::pair<APValue, QualType>> ReverseList;
+      // Visit the fields.
+      for (FieldDecl *FD : RD->fields()) {
+        if (FD->isUnnamedBitField())
+          continue;
+        ReverseList.emplace_back(Work.getStructField(FD->getFieldIndex()),
+                                 FD->getType());
+      }
+
+      std::reverse(ReverseList.begin(), ReverseList.end());
+      llvm::append_range(WorkList, ReverseList);
+
+      // Visit the base classes.
+      if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
+        if (CXXRD->getNumBases() > 0) {
+          assert(CXXRD->getNumBases() == 1);
+          const CXXBaseSpecifier &BS = CXXRD->bases_begin()[0];
+          const APValue &Base = Work.getStructBase(0);
+
+          // Can happen in error cases.
+          if (!Base.isStruct())
+            return false;
+
+          WorkList.emplace_back(Base, BS.getType());
+        }
+      }
+      continue;
+    }
+    return false;
+  }
+  return true;
----------------
bob80905 wrote:

Do you want to return true here even when Populated >= Size and there are still items remaining in the WorkList? I think that would indicate that the flattening could not be completed within the given size.
So I would have expected something like `return WorkList.empty();` here.

https://github.com/llvm/llvm-project/pull/164700


More information about the cfe-commits mailing list