[clang-tools-extra] 2ec71d9 - [clang] NFC: introduce Type::getAsEnumDecl, and cast variants for all TagDecls (#155463)

via cfe-commits cfe-commits at lists.llvm.org
Tue Aug 26 12:06:03 PDT 2025


Author: Matheus Izvekov
Date: 2025-08-26T16:05:59-03:00
New Revision: 2ec71d93ad888d9523425930ef8c35fe8f0b2485

URL: https://github.com/llvm/llvm-project/commit/2ec71d93ad888d9523425930ef8c35fe8f0b2485
DIFF: https://github.com/llvm/llvm-project/commit/2ec71d93ad888d9523425930ef8c35fe8f0b2485.diff

LOG: [clang] NFC: introduce Type::getAsEnumDecl, and cast variants for all TagDecls (#155463)

And make use of those.

These changes are split from prior PR #155028, in order to decrease the
size of that PR and facilitate review.

Added: 
    

Modified: 
    clang-tools-extra/clang-tidy/bugprone/TaggedUnionMemberCountCheck.cpp
    clang-tools-extra/clang-tidy/utils/ExceptionSpecAnalyzer.cpp
    clang-tools-extra/clang-tidy/utils/FormatStringConverter.cpp
    clang-tools-extra/clangd/Hover.cpp
    clang/include/clang/AST/Decl.h
    clang/include/clang/AST/Type.h
    clang/lib/AST/APValue.cpp
    clang/lib/AST/ASTContext.cpp
    clang/lib/AST/ByteCode/Compiler.cpp
    clang/lib/AST/ByteCode/Context.cpp
    clang/lib/AST/ByteCode/Program.cpp
    clang/lib/AST/CXXInheritance.cpp
    clang/lib/AST/DeclCXX.cpp
    clang/lib/AST/Expr.cpp
    clang/lib/AST/ExprConstant.cpp
    clang/lib/AST/FormatString.cpp
    clang/lib/AST/ItaniumCXXABI.cpp
    clang/lib/AST/ItaniumMangle.cpp
    clang/lib/AST/PrintfFormatString.cpp
    clang/lib/AST/ScanfFormatString.cpp
    clang/lib/AST/TemplateBase.cpp
    clang/lib/AST/Type.cpp
    clang/lib/AST/VTTBuilder.cpp
    clang/lib/CIR/CodeGen/CIRGenCall.cpp
    clang/lib/CIR/CodeGen/CIRGenClass.cpp
    clang/lib/CIR/CodeGen/CIRGenExpr.cpp
    clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp
    clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
    clang/lib/CIR/CodeGen/CIRGenModule.cpp
    clang/lib/CIR/CodeGen/CIRGenTypes.cpp
    clang/lib/CodeGen/ABIInfoImpl.cpp
    clang/lib/CodeGen/CGCUDANV.cpp
    clang/lib/CodeGen/CGCXX.cpp
    clang/lib/CodeGen/CGCall.cpp
    clang/lib/CodeGen/CGClass.cpp
    clang/lib/CodeGen/CGDebugInfo.cpp
    clang/lib/CodeGen/CGDecl.cpp
    clang/lib/CodeGen/CGExpr.cpp
    clang/lib/CodeGen/CGExprAgg.cpp
    clang/lib/CodeGen/CGExprCXX.cpp
    clang/lib/CodeGen/CGExprConstant.cpp
    clang/lib/CodeGen/CGExprScalar.cpp
    clang/lib/CodeGen/CGNonTrivialStruct.cpp
    clang/lib/CodeGen/CGOpenMPRuntime.cpp
    clang/lib/CodeGen/CodeGenTypes.cpp
    clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp
    clang/lib/CodeGen/ItaniumCXXABI.cpp
    clang/lib/CodeGen/Targets/AArch64.cpp
    clang/lib/CodeGen/Targets/ARC.cpp
    clang/lib/CodeGen/Targets/ARM.cpp
    clang/lib/CodeGen/Targets/BPF.cpp
    clang/lib/CodeGen/Targets/CSKY.cpp
    clang/lib/CodeGen/Targets/Hexagon.cpp
    clang/lib/CodeGen/Targets/Lanai.cpp
    clang/lib/CodeGen/Targets/LoongArch.cpp
    clang/lib/CodeGen/Targets/Mips.cpp
    clang/lib/CodeGen/Targets/NVPTX.cpp
    clang/lib/CodeGen/Targets/PPC.cpp
    clang/lib/CodeGen/Targets/RISCV.cpp
    clang/lib/CodeGen/Targets/Sparc.cpp
    clang/lib/CodeGen/Targets/SystemZ.cpp
    clang/lib/CodeGen/Targets/WebAssembly.cpp
    clang/lib/CodeGen/Targets/X86.cpp
    clang/lib/Frontend/Rewrite/RewriteModernObjC.cpp
    clang/lib/Frontend/Rewrite/RewriteObjC.cpp
    clang/lib/Index/IndexTypeSourceInfo.cpp
    clang/lib/Interpreter/InterpreterValuePrinter.cpp
    clang/lib/Interpreter/Value.cpp
    clang/lib/Sema/SemaAccess.cpp
    clang/lib/Sema/SemaBPF.cpp
    clang/lib/Sema/SemaCXXScopeSpec.cpp
    clang/lib/Sema/SemaCast.cpp
    clang/lib/Sema/SemaChecking.cpp
    clang/lib/Sema/SemaCodeComplete.cpp
    clang/lib/Sema/SemaDecl.cpp
    clang/lib/Sema/SemaDeclCXX.cpp
    clang/lib/Sema/SemaExpr.cpp
    clang/lib/Sema/SemaExprCXX.cpp
    clang/lib/Sema/SemaExprObjC.cpp
    clang/lib/Sema/SemaHLSL.cpp
    clang/lib/Sema/SemaInit.cpp
    clang/lib/Sema/SemaLambda.cpp
    clang/lib/Sema/SemaLookup.cpp
    clang/lib/Sema/SemaOpenMP.cpp
    clang/lib/Sema/SemaOverload.cpp
    clang/lib/Sema/SemaPPC.cpp
    clang/lib/Sema/SemaStmt.cpp
    clang/lib/Sema/SemaTemplate.cpp
    clang/lib/Sema/SemaType.cpp
    clang/lib/StaticAnalyzer/Checkers/EnumCastOutOfRangeChecker.cpp
    clang/lib/StaticAnalyzer/Core/RegionStore.cpp

Removed: 
    


################################################################################
diff  --git a/clang-tools-extra/clang-tidy/bugprone/TaggedUnionMemberCountCheck.cpp b/clang-tools-extra/clang-tidy/bugprone/TaggedUnionMemberCountCheck.cpp
index ddbb14e3ac62b..02f4421efdbf4 100644
--- a/clang-tools-extra/clang-tidy/bugprone/TaggedUnionMemberCountCheck.cpp
+++ b/clang-tools-extra/clang-tidy/bugprone/TaggedUnionMemberCountCheck.cpp
@@ -169,15 +169,8 @@ void TaggedUnionMemberCountCheck::check(
   if (!Root || !UnionField || !TagField)
     return;
 
-  const auto *UnionDef =
-      UnionField->getType().getCanonicalType().getTypePtr()->getAsRecordDecl();
-  const auto *EnumDef = llvm::dyn_cast<EnumDecl>(
-      TagField->getType().getCanonicalType().getTypePtr()->getAsTagDecl());
-
-  assert(UnionDef && "UnionDef is missing!");
-  assert(EnumDef && "EnumDef is missing!");
-  if (!UnionDef || !EnumDef)
-    return;
+  const auto *UnionDef = UnionField->getType()->castAsRecordDecl();
+  const auto *EnumDef = TagField->getType()->castAsEnumDecl();
 
   const std::size_t UnionMemberCount = llvm::range_size(UnionDef->fields());
   auto [TagCount, CountingEnumConstantDecl] = getNumberOfEnumValues(EnumDef);

diff  --git a/clang-tools-extra/clang-tidy/utils/ExceptionSpecAnalyzer.cpp b/clang-tools-extra/clang-tidy/utils/ExceptionSpecAnalyzer.cpp
index aa6aefcf0c493..4314817e4f69d 100644
--- a/clang-tools-extra/clang-tidy/utils/ExceptionSpecAnalyzer.cpp
+++ b/clang-tools-extra/clang-tidy/utils/ExceptionSpecAnalyzer.cpp
@@ -66,10 +66,7 @@ ExceptionSpecAnalyzer::analyzeBase(const CXXBaseSpecifier &Base,
   if (!RecType)
     return State::Unknown;
 
-  const auto *BaseClass =
-      cast<CXXRecordDecl>(RecType->getOriginalDecl())->getDefinitionOrSelf();
-
-  return analyzeRecord(BaseClass, Kind);
+  return analyzeRecord(RecType->getAsCXXRecordDecl(), Kind);
 }
 
 ExceptionSpecAnalyzer::State

diff  --git a/clang-tools-extra/clang-tidy/utils/FormatStringConverter.cpp b/clang-tools-extra/clang-tidy/utils/FormatStringConverter.cpp
index 0df8e913100fc..0d0834dc38fc6 100644
--- a/clang-tools-extra/clang-tidy/utils/FormatStringConverter.cpp
+++ b/clang-tools-extra/clang-tidy/utils/FormatStringConverter.cpp
@@ -460,10 +460,9 @@ bool FormatStringConverter::emitIntegerArgument(
     // be passed as its underlying type. However, printf will have forced
     // the signedness based on the format string, so we need to do the
     // same.
-    if (const auto *ET = ArgType->getAs<EnumType>()) {
-      if (const std::optional<std::string> MaybeCastType = castTypeForArgument(
-              ArgKind,
-              ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType()))
+    if (const auto *ED = ArgType->getAsEnumDecl()) {
+      if (const std::optional<std::string> MaybeCastType =
+              castTypeForArgument(ArgKind, ED->getIntegerType()))
         ArgFixes.emplace_back(
             ArgIndex, (Twine("static_cast<") + *MaybeCastType + ">(").str());
       else

diff  --git a/clang-tools-extra/clangd/Hover.cpp b/clang-tools-extra/clangd/Hover.cpp
index af00c8948a215..30c70ac02205b 100644
--- a/clang-tools-extra/clangd/Hover.cpp
+++ b/clang-tools-extra/clangd/Hover.cpp
@@ -454,8 +454,7 @@ std::optional<std::string> printExprValue(const Expr *E,
       Constant.Val.getInt().getSignificantBits() <= 64) {
     // Compare to int64_t to avoid bit-width match requirements.
     int64_t Val = Constant.Val.getInt().getExtValue();
-    for (const EnumConstantDecl *ECD :
-         T->castAs<EnumType>()->getOriginalDecl()->enumerators())
+    for (const EnumConstantDecl *ECD : T->castAsEnumDecl()->enumerators())
       if (ECD->getInitVal() == Val)
         return llvm::formatv("{0} ({1})", ECD->getNameAsString(),
                              printHex(Constant.Val.getInt()))
@@ -832,7 +831,7 @@ std::optional<HoverInfo> getThisExprHoverContents(const CXXThisExpr *CTE,
                                                   ASTContext &ASTCtx,
                                                   const PrintingPolicy &PP) {
   QualType OriginThisType = CTE->getType()->getPointeeType();
-  QualType ClassType = declaredType(OriginThisType->getAsTagDecl());
+  QualType ClassType = declaredType(OriginThisType->castAsTagDecl());
   // For partial specialization class, origin `this` pointee type will be
   // parsed as `InjectedClassNameType`, which will ouput template arguments
   // like "type-parameter-0-0". So we retrieve user written class type in this

diff  --git a/clang/include/clang/AST/Decl.h b/clang/include/clang/AST/Decl.h
index bebbde3661a33..79636a67dafba 100644
--- a/clang/include/clang/AST/Decl.h
+++ b/clang/include/clang/AST/Decl.h
@@ -3915,6 +3915,10 @@ class TagDecl : public TypeDecl,
   bool isUnion() const { return getTagKind() == TagTypeKind::Union; }
   bool isEnum() const { return getTagKind() == TagTypeKind::Enum; }
 
+  bool isStructureOrClass() const {
+    return isStruct() || isClass() || isInterface();
+  }
+
   /// Is this tag type named, either directly or via being defined in
   /// a typedef of this type?
   ///

diff  --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index adf5cb0462154..187e54f5cb54b 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2883,14 +2883,21 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   /// because the type is a RecordType or because it is the injected-class-name
   /// type of a class template or class template partial specialization.
   CXXRecordDecl *getAsCXXRecordDecl() const;
+  CXXRecordDecl *castAsCXXRecordDecl() const;
 
   /// Retrieves the RecordDecl this type refers to.
   RecordDecl *getAsRecordDecl() const;
+  RecordDecl *castAsRecordDecl() const;
+
+  /// Retrieves the EnumDecl this type refers to.
+  EnumDecl *getAsEnumDecl() const;
+  EnumDecl *castAsEnumDecl() const;
 
   /// Retrieves the TagDecl that this type refers to, either
   /// because the type is a TagType or because it is the injected-class-name
   /// type of a class template or class template partial specialization.
   TagDecl *getAsTagDecl() const;
+  TagDecl *castAsTagDecl() const;
 
   /// If this is a pointer or reference to a RecordType, return the
   /// CXXRecordDecl that the type refers to.

diff  --git a/clang/lib/AST/APValue.cpp b/clang/lib/AST/APValue.cpp
index 2d62209bbc28c..7173c2a0e1a2a 100644
--- a/clang/lib/AST/APValue.cpp
+++ b/clang/lib/AST/APValue.cpp
@@ -903,8 +903,7 @@ void APValue::printPretty(raw_ostream &Out, const PrintingPolicy &Policy,
   case APValue::Struct: {
     Out << '{';
     bool First = true;
-    const RecordDecl *RD =
-        Ty->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *RD = Ty->castAsRecordDecl();
     if (unsigned N = getStructNumBases()) {
       const CXXRecordDecl *CD = cast<CXXRecordDecl>(RD);
       CXXRecordDecl::base_class_const_iterator BI = CD->bases_begin();

diff  --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 5fc55b2675fd2..06e7a2d5b857b 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -2001,8 +2001,7 @@ bool ASTContext::isPromotableIntegerType(QualType T) const {
 
   // Enumerated types are promotable to their compatible integer types
   // (C99 6.3.1.1) a.k.a. its underlying type (C++ [conv.prom]p2).
-  if (const auto *ET = T->getAs<EnumType>()) {
-    const EnumDecl *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *ED = T->getAsEnumDecl()) {
     if (T->isDependentType() || ED->getPromotionType().isNull() ||
         ED->isScoped())
       return false;
@@ -2712,11 +2711,8 @@ unsigned ASTContext::getPreferredTypeAlign(const Type *T) const {
   // possible.
   if (const auto *CT = T->getAs<ComplexType>())
     T = CT->getElementType().getTypePtr();
-  if (const auto *ET = T->getAs<EnumType>())
-    T = ET->getOriginalDecl()
-            ->getDefinitionOrSelf()
-            ->getIntegerType()
-            .getTypePtr();
+  if (const auto *ED = T->getAsEnumDecl())
+    T = ED->getIntegerType().getTypePtr();
   if (T->isSpecificBuiltinType(BuiltinType::Double) ||
       T->isSpecificBuiltinType(BuiltinType::LongLong) ||
       T->isSpecificBuiltinType(BuiltinType::ULongLong) ||
@@ -3412,10 +3408,7 @@ static void encodeTypeForFunctionPointerAuth(const ASTContext &Ctx,
     //   type, or an unsigned integer type.
     //
     // So we have to treat enum types as integers.
-    QualType UnderlyingType = cast<EnumType>(T)
-                                  ->getOriginalDecl()
-                                  ->getDefinitionOrSelf()
-                                  ->getIntegerType();
+    QualType UnderlyingType = T->castAsEnumDecl()->getIntegerType();
     return encodeTypeForFunctionPointerAuth(
         Ctx, OS, UnderlyingType.isNull() ? Ctx.IntTy : UnderlyingType);
   }
@@ -8351,8 +8344,8 @@ QualType ASTContext::isPromotableBitField(Expr *E) const {
 QualType ASTContext::getPromotedIntegerType(QualType Promotable) const {
   assert(!Promotable.isNull());
   assert(isPromotableIntegerType(Promotable));
-  if (const auto *ET = Promotable->getAs<EnumType>())
-    return ET->getOriginalDecl()->getDefinitionOrSelf()->getPromotionType();
+  if (const auto *ED = Promotable->getAsEnumDecl())
+    return ED->getPromotionType();
 
   if (const auto *BT = Promotable->getAs<BuiltinType>()) {
     // C++ [conv.prom]: A prvalue of type char16_t, char32_t, or wchar_t
@@ -8571,10 +8564,9 @@ QualType ASTContext::getObjCSuperType() const {
 }
 
 void ASTContext::setCFConstantStringType(QualType T) {
-  const auto *TD = T->castAs<TypedefType>();
-  CFConstantStringTypeDecl = cast<TypedefDecl>(TD->getDecl());
-  const auto *TagType = TD->castAs<RecordType>();
-  CFConstantStringTagDecl = TagType->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *TT = T->castAs<TypedefType>();
+  CFConstantStringTypeDecl = cast<TypedefDecl>(TT->getDecl());
+  CFConstantStringTagDecl = TT->castAsRecordDecl();
 }
 
 QualType ASTContext::getBlockDescriptorType() const {
@@ -11667,9 +11659,8 @@ QualType ASTContext::mergeFunctionTypes(QualType lhs, QualType rhs,
 
       // Look at the converted type of enum types, since that is the type used
       // to pass enum values.
-      if (const auto *Enum = paramTy->getAs<EnumType>()) {
-        paramTy =
-            Enum->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+      if (const auto *ED = paramTy->getAsEnumDecl()) {
+        paramTy = ED->getIntegerType();
         if (paramTy.isNull())
           return {};
       }
@@ -12260,8 +12251,8 @@ QualType ASTContext::mergeObjCGCQualifiers(QualType LHS, QualType RHS) {
 //===----------------------------------------------------------------------===//
 
 unsigned ASTContext::getIntWidth(QualType T) const {
-  if (const auto *ET = T->getAs<EnumType>())
-    T = ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = T->getAsEnumDecl())
+    T = ED->getIntegerType();
   if (T->isBooleanType())
     return 1;
   if (const auto *EIT = T->getAs<BitIntType>())
@@ -12286,8 +12277,8 @@ QualType ASTContext::getCorrespondingUnsignedType(QualType T) const {
 
   // For enums, get the underlying integer type of the enum, and let the general
   // integer type signchanging code handle it.
-  if (const auto *ETy = T->getAs<EnumType>())
-    T = ETy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = T->getAsEnumDecl())
+    T = ED->getIntegerType();
 
   switch (T->castAs<BuiltinType>()->getKind()) {
   case BuiltinType::Char_U:
@@ -12360,8 +12351,8 @@ QualType ASTContext::getCorrespondingSignedType(QualType T) const {
 
   // For enums, get the underlying integer type of the enum, and let the general
   // integer type signchanging code handle it.
-  if (const auto *ETy = T->getAs<EnumType>())
-    T = ETy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = T->getAsEnumDecl())
+    T = ED->getIntegerType();
 
   switch (T->castAs<BuiltinType>()->getKind()) {
   case BuiltinType::Char_S:

diff  --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 8610267ff6650..0079d937e885b 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -559,8 +559,7 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
     // Possibly diagnose casts to enum types if the target type does not
     // have a fixed size.
     if (Ctx.getLangOpts().CPlusPlus && CE->getType()->isEnumeralType()) {
-      const auto *ET = CE->getType().getCanonicalType()->castAs<EnumType>();
-      const auto *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+      const auto *ED = CE->getType()->castAsEnumDecl();
       if (!ED->isFixed()) {
         if (!this->emitCheckEnumValue(*FromT, ED, CE))
           return false;

diff  --git a/clang/lib/AST/ByteCode/Context.cpp b/clang/lib/AST/ByteCode/Context.cpp
index 36eb7607e70bf..fbbb508ed226c 100644
--- a/clang/lib/AST/ByteCode/Context.cpp
+++ b/clang/lib/AST/ByteCode/Context.cpp
@@ -364,8 +364,7 @@ OptPrimType Context::classify(QualType T) const {
     return integralTypeToPrimTypeU(BT->getNumBits());
   }
 
-  if (const auto *ET = T->getAs<EnumType>()) {
-    const auto *D = ET->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *D = T->getAsEnumDecl()) {
     if (!D->isComplete())
       return std::nullopt;
     return classify(D->getIntegerType());

diff  --git a/clang/lib/AST/ByteCode/Program.cpp b/clang/lib/AST/ByteCode/Program.cpp
index d9403c25d598b..0be017ea59b91 100644
--- a/clang/lib/AST/ByteCode/Program.cpp
+++ b/clang/lib/AST/ByteCode/Program.cpp
@@ -347,11 +347,7 @@ Record *Program::getOrCreateRecord(const RecordDecl *RD) {
     }
 
     for (const CXXBaseSpecifier &Spec : CD->vbases()) {
-      const auto *RT = Spec.getType()->getAs<RecordType>();
-      if (!RT)
-        return nullptr;
-
-      const RecordDecl *BD = RT->getOriginalDecl()->getDefinitionOrSelf();
+      const auto *BD = Spec.getType()->castAsCXXRecordDecl();
       const Record *BR = getOrCreateRecord(BD);
 
       const Descriptor *Desc = GetBaseDesc(BD, BR);

diff  --git a/clang/lib/AST/CXXInheritance.cpp b/clang/lib/AST/CXXInheritance.cpp
index 0ced210900b1a..94f01c86a16ca 100644
--- a/clang/lib/AST/CXXInheritance.cpp
+++ b/clang/lib/AST/CXXInheritance.cpp
@@ -263,7 +263,7 @@ bool CXXBasePaths::lookupInBases(ASTContext &Context,
             BaseRecord = nullptr;
         }
       } else {
-        BaseRecord = cast<CXXRecordDecl>(BaseSpec.getType()->getAsRecordDecl());
+        BaseRecord = BaseSpec.getType()->castAsCXXRecordDecl();
       }
       if (BaseRecord &&
           lookupInBases(Context, BaseRecord, BaseMatches, LookupInDependent)) {
@@ -327,10 +327,7 @@ bool CXXRecordDecl::lookupInBases(BaseMatchesCallback BaseMatches,
       if (!PE.Base->isVirtual())
         continue;
 
-      CXXRecordDecl *VBase = nullptr;
-      if (const RecordType *Record = PE.Base->getType()->getAs<RecordType>())
-        VBase = cast<CXXRecordDecl>(Record->getOriginalDecl())
-                    ->getDefinitionOrSelf();
+      auto *VBase = PE.Base->getType()->getAsCXXRecordDecl();
       if (!VBase)
         break;
 
@@ -396,7 +393,7 @@ bool CXXRecordDecl::hasMemberName(DeclarationName Name) const {
   CXXBasePaths Paths(false, false, false);
   return lookupInBases(
       [Name](const CXXBaseSpecifier *Specifier, CXXBasePath &Path) {
-        return findOrdinaryMember(Specifier->getType()->getAsCXXRecordDecl(),
+        return findOrdinaryMember(Specifier->getType()->castAsCXXRecordDecl(),
                                   Path, Name);
       },
       Paths);

diff  --git a/clang/lib/AST/DeclCXX.cpp b/clang/lib/AST/DeclCXX.cpp
index 62eb4de8c6a96..86d3b136ce0b5 100644
--- a/clang/lib/AST/DeclCXX.cpp
+++ b/clang/lib/AST/DeclCXX.cpp
@@ -216,9 +216,7 @@ CXXRecordDecl::setBases(CXXBaseSpecifier const * const *Bases,
     // Skip dependent types; we can't do any checking on them now.
     if (BaseType->isDependentType())
       continue;
-    auto *BaseClassDecl =
-        cast<CXXRecordDecl>(BaseType->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
+    auto *BaseClassDecl = BaseType->castAsCXXRecordDecl();
 
     // C++2a [class]p7:
     //   A standard-layout class is a class that:
@@ -3432,13 +3430,12 @@ SourceRange UsingDecl::getSourceRange() const {
 void UsingEnumDecl::anchor() {}
 
 UsingEnumDecl *UsingEnumDecl::Create(ASTContext &C, DeclContext *DC,
-                                     SourceLocation UL,
-                                     SourceLocation EL,
+                                     SourceLocation UL, SourceLocation EL,
                                      SourceLocation NL,
                                      TypeSourceInfo *EnumType) {
-  assert(isa<EnumDecl>(EnumType->getType()->getAsTagDecl()));
   return new (C, DC)
-      UsingEnumDecl(DC, EnumType->getType()->getAsTagDecl()->getDeclName(), UL, EL, NL, EnumType);
+      UsingEnumDecl(DC, EnumType->getType()->castAsEnumDecl()->getDeclName(),
+                    UL, EL, NL, EnumType);
 }
 
 UsingEnumDecl *UsingEnumDecl::CreateDeserialized(ASTContext &C,

diff  --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index 9d1490c2ef834..072d07cb81179 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -74,8 +74,7 @@ const CXXRecordDecl *Expr::getBestDynamicClassType() const {
   if (DerivedType->isDependentType())
     return nullptr;
 
-  const RecordType *Ty = DerivedType->castAs<RecordType>();
-  return cast<CXXRecordDecl>(Ty->getOriginalDecl())->getDefinitionOrSelf();
+  return DerivedType->castAsCXXRecordDecl();
 }
 
 const Expr *Expr::skipRValueSubobjectAdjustments(
@@ -90,10 +89,7 @@ const Expr *Expr::skipRValueSubobjectAdjustments(
            CE->getCastKind() == CK_UncheckedDerivedToBase) &&
           E->getType()->isRecordType()) {
         E = CE->getSubExpr();
-        const auto *Derived =
-            cast<CXXRecordDecl>(
-                E->getType()->castAs<RecordType>()->getOriginalDecl())
-                ->getDefinitionOrSelf();
+        const auto *Derived = E->getType()->castAsCXXRecordDecl();
         Adjustments.push_back(SubobjectAdjustment(CE, Derived));
         continue;
       }
@@ -2032,9 +2028,7 @@ CXXBaseSpecifier **CastExpr::path_buffer() {
 
 const FieldDecl *CastExpr::getTargetFieldForToUnionCast(QualType unionType,
                                                         QualType opType) {
-  auto RD =
-      unionType->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
-  return getTargetFieldForToUnionCast(RD, opType);
+  return getTargetFieldForToUnionCast(unionType->castAsRecordDecl(), opType);
 }
 
 const FieldDecl *CastExpr::getTargetFieldForToUnionCast(const RecordDecl *RD,
@@ -3396,10 +3390,7 @@ bool Expr::isConstantInitializer(ASTContext &Ctx, bool IsForRef,
 
     if (ILE->getType()->isRecordType()) {
       unsigned ElementNo = 0;
-      RecordDecl *RD = ILE->getType()
-                           ->castAs<RecordType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf();
+      auto *RD = ILE->getType()->castAsRecordDecl();
 
       // In C++17, bases were added to the list of members used by aggregate
       // initialization.

diff  --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 4ce04d0227ee3..19703e40d2696 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -2614,8 +2614,7 @@ static bool CheckEvaluationResult(CheckEvaluationResultKind CERK,
         Value.getUnionValue(), Kind, Value.getUnionField(), CheckedTemps);
   }
   if (Value.isStruct()) {
-    RecordDecl *RD =
-        Type->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    auto *RD = Type->castAsRecordDecl();
     if (const CXXRecordDecl *CD = dyn_cast<CXXRecordDecl>(RD)) {
       unsigned BaseIndex = 0;
       for (const CXXBaseSpecifier &BS : CD->bases()) {
@@ -10769,8 +10768,7 @@ static bool HandleClassZeroInitialization(EvalInfo &Info, const Expr *E,
 }
 
 bool RecordExprEvaluator::ZeroInitialization(const Expr *E, QualType T) {
-  const RecordDecl *RD =
-      T->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *RD = T->castAsRecordDecl();
   if (RD->isInvalidDecl()) return false;
   if (RD->isUnion()) {
     // C++11 [dcl.init]p5: If T is a (possibly cv-qualified) union type, the
@@ -10839,10 +10837,7 @@ bool RecordExprEvaluator::VisitInitListExpr(const InitListExpr *E) {
 
 bool RecordExprEvaluator::VisitCXXParenListOrInitListExpr(
     const Expr *ExprToVisit, ArrayRef<Expr *> Args) {
-  const RecordDecl *RD = ExprToVisit->getType()
-                             ->castAs<RecordType>()
-                             ->getOriginalDecl()
-                             ->getDefinitionOrSelf();
+  const auto *RD = ExprToVisit->getType()->castAsRecordDecl();
   if (RD->isInvalidDecl()) return false;
   const ASTRecordLayout &Layout = Info.Ctx.getASTRecordLayout(RD);
   auto *CXXRD = dyn_cast<CXXRecordDecl>(RD);
@@ -11066,10 +11061,7 @@ bool RecordExprEvaluator::VisitCXXStdInitializerListExpr(
   Result = APValue(APValue::UninitStruct(), 0, 2);
   Array.moveInto(Result.getStructField(0));
 
-  RecordDecl *Record = E->getType()
-                           ->castAs<RecordType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf();
+  auto *Record = E->getType()->castAsRecordDecl();
   RecordDecl::field_iterator Field = Record->field_begin();
   assert(Field != Record->field_end() &&
          Info.Ctx.hasSameType(Field->getType()->getPointeeType(),
@@ -13131,10 +13123,7 @@ static bool convertUnsignedAPIntToCharUnits(const llvm::APInt &Int,
 static void addFlexibleArrayMemberInitSize(EvalInfo &Info, const QualType &T,
                                            const LValue &LV, CharUnits &Size) {
   if (!T.isNull() && T->isStructureType() &&
-      T->getAsStructureType()
-          ->getOriginalDecl()
-          ->getDefinitionOrSelf()
-          ->hasFlexibleArrayMember())
+      T->castAsRecordDecl()->hasFlexibleArrayMember())
     if (const auto *V = LV.getLValueBase().dyn_cast<const ValueDecl *>())
       if (const auto *VD = dyn_cast<VarDecl>(V))
         if (VD->hasInit())
@@ -15655,8 +15644,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
     }
 
     if (Info.Ctx.getLangOpts().CPlusPlus && DestType->isEnumeralType()) {
-      const EnumType *ET = dyn_cast<EnumType>(DestType.getCanonicalType());
-      const EnumDecl *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+      const auto *ED = DestType->getAsEnumDecl();
       // Check that the value is within the range of the enumeration values.
       //
       // This corressponds to [expr.static.cast]p10 which says:

diff  --git a/clang/lib/AST/FormatString.cpp b/clang/lib/AST/FormatString.cpp
index a73c02734e558..d4cb89b43ae87 100644
--- a/clang/lib/AST/FormatString.cpp
+++ b/clang/lib/AST/FormatString.cpp
@@ -413,14 +413,13 @@ ArgType::matchesType(ASTContext &C, QualType argTy) const {
       return Match;
 
     case AnyCharTy: {
-      if (const auto *ETy = argTy->getAs<EnumType>()) {
+      if (const auto *ED = argTy->getAsEnumDecl()) {
         // If the enum is incomplete we know nothing about the underlying type.
         // Assume that it's 'int'. Do not use the underlying type for a scoped
         // enumeration.
-        const EnumDecl *ED = ETy->getOriginalDecl()->getDefinitionOrSelf();
         if (!ED->isComplete())
           return NoMatch;
-        if (ETy->isUnscopedEnumerationType())
+        if (!ED->isScoped())
           argTy = ED->getIntegerType();
       }
 
@@ -463,14 +462,13 @@ ArgType::matchesType(ASTContext &C, QualType argTy) const {
        return matchesSizeTPtrdiffT(C, argTy, T);
       }
 
-      if (const EnumType *ETy = argTy->getAs<EnumType>()) {
+      if (const auto *ED = argTy->getAsEnumDecl()) {
         // If the enum is incomplete we know nothing about the underlying type.
         // Assume that it's 'int'. Do not use the underlying type for a scoped
         // enumeration as that needs an exact match.
-        const EnumDecl *ED = ETy->getOriginalDecl()->getDefinitionOrSelf();
         if (!ED->isComplete())
           argTy = C.IntTy;
-        else if (ETy->isUnscopedEnumerationType())
+        else if (!ED->isScoped())
           argTy = ED->getIntegerType();
       }
 

diff  --git a/clang/lib/AST/ItaniumCXXABI.cpp b/clang/lib/AST/ItaniumCXXABI.cpp
index 43a8bcd9443ff..adef1584fd9b6 100644
--- a/clang/lib/AST/ItaniumCXXABI.cpp
+++ b/clang/lib/AST/ItaniumCXXABI.cpp
@@ -42,8 +42,7 @@ namespace {
 ///
 /// Returns the name of anonymous union VarDecl or nullptr if it is not found.
 static const IdentifierInfo *findAnonymousUnionVarDeclName(const VarDecl& VD) {
-  const auto *RT = VD.getType()->castAs<RecordType>();
-  const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *RD = VD.getType()->castAsRecordDecl();
   assert(RD->isUnion() && "RecordType is expected to be a union.");
   if (const FieldDecl *FD = RD->findFirstNamedDataMember()) {
     return FD->getIdentifier();

diff  --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index c7eaad47710df..1ba224b74a606 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -1580,10 +1580,7 @@ void CXXNameMangler::mangleUnqualifiedName(
 
     if (const VarDecl *VD = dyn_cast<VarDecl>(ND)) {
       // We must have an anonymous union or struct declaration.
-      const RecordDecl *RD = VD->getType()
-                                 ->castAs<RecordType>()
-                                 ->getOriginalDecl()
-                                 ->getDefinitionOrSelf();
+      const auto *RD = VD->getType()->castAsRecordDecl();
 
       // Itanium C++ ABI 5.1.2:
       //

diff  --git a/clang/lib/AST/PrintfFormatString.cpp b/clang/lib/AST/PrintfFormatString.cpp
index 687160c6116be..855550475721a 100644
--- a/clang/lib/AST/PrintfFormatString.cpp
+++ b/clang/lib/AST/PrintfFormatString.cpp
@@ -793,8 +793,8 @@ bool PrintfSpecifier::fixType(QualType QT, const LangOptions &LangOpt,
   }
 
   // If it's an enum, get its underlying type.
-  if (const EnumType *ETy = QT->getAs<EnumType>())
-    QT = ETy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = QT->getAsEnumDecl())
+    QT = ED->getIntegerType();
 
   const BuiltinType *BT = QT->getAs<BuiltinType>();
   if (!BT) {

diff  --git a/clang/lib/AST/ScanfFormatString.cpp b/clang/lib/AST/ScanfFormatString.cpp
index 31c001d025fea..41cf71a3e042d 100644
--- a/clang/lib/AST/ScanfFormatString.cpp
+++ b/clang/lib/AST/ScanfFormatString.cpp
@@ -430,9 +430,8 @@ bool ScanfSpecifier::fixType(QualType QT, QualType RawQT,
   QualType PT = QT->getPointeeType();
 
   // If it's an enum, get its underlying type.
-  if (const EnumType *ETy = PT->getAs<EnumType>()) {
+  if (const auto *ED = PT->getAsEnumDecl()) {
     // Don't try to fix incomplete enums.
-    const EnumDecl *ED = ETy->getOriginalDecl()->getDefinitionOrSelf();
     if (!ED->isComplete())
       return false;
     PT = ED->getIntegerType();

diff  --git a/clang/lib/AST/TemplateBase.cpp b/clang/lib/AST/TemplateBase.cpp
index 76050ceeb35a7..76f96fb8c5dcc 100644
--- a/clang/lib/AST/TemplateBase.cpp
+++ b/clang/lib/AST/TemplateBase.cpp
@@ -56,8 +56,8 @@ static void printIntegral(const TemplateArgument &TemplArg, raw_ostream &Out,
   const llvm::APSInt &Val = TemplArg.getAsIntegral();
 
   if (Policy.UseEnumerators) {
-    if (const EnumType *ET = T->getAs<EnumType>()) {
-      for (const EnumConstantDecl *ECD : ET->getOriginalDecl()->enumerators()) {
+    if (const auto *ED = T->getAsEnumDecl()) {
+      for (const EnumConstantDecl *ECD : ED->enumerators()) {
         // In Sema::CheckTemplateArugment, enum template arguments value are
         // extended to the size of the integer underlying the enum type.  This
         // may create a size difference between the enum value and template

diff  --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 63d808f631591..b0f865bab5d77 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -685,17 +685,15 @@ bool Type::isStructureTypeWithFlexibleArrayMember() const {
   const auto *RT = getAs<RecordType>();
   if (!RT)
     return false;
-  const auto *Decl = RT->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *Decl = RT->getOriginalDecl();
   if (!Decl->isStruct())
     return false;
-  return Decl->hasFlexibleArrayMember();
+  return Decl->getDefinitionOrSelf()->hasFlexibleArrayMember();
 }
 
 bool Type::isObjCBoxableRecordType() const {
-  if (const auto *RT = getAs<RecordType>())
-    return RT->getOriginalDecl()
-        ->getDefinitionOrSelf()
-        ->hasAttr<ObjCBoxableAttr>();
+  if (const auto *RD = getAsRecordDecl())
+    return RD->hasAttr<ObjCBoxableAttr>();
   return false;
 }
 
@@ -1918,12 +1916,7 @@ const CXXRecordDecl *Type::getPointeeCXXRecordDecl() const {
     PointeeType = RT->getPointeeType();
   else
     return nullptr;
-
-  if (const auto *RT = PointeeType->getAs<RecordType>())
-    return dyn_cast<CXXRecordDecl>(
-        RT->getOriginalDecl()->getDefinitionOrSelf());
-
-  return nullptr;
+  return PointeeType->getAsCXXRecordDecl();
 }
 
 CXXRecordDecl *Type::getAsCXXRecordDecl() const {
@@ -1936,6 +1929,11 @@ CXXRecordDecl *Type::getAsCXXRecordDecl() const {
   return cast<CXXRecordDecl>(TD)->getDefinitionOrSelf();
 }
 
+CXXRecordDecl *Type::castAsCXXRecordDecl() const {
+  const auto *TT = cast<TagType>(CanonicalType);
+  return cast<CXXRecordDecl>(TT->getOriginalDecl())->getDefinitionOrSelf();
+}
+
 RecordDecl *Type::getAsRecordDecl() const {
   const auto *TT = dyn_cast<TagType>(CanonicalType);
   if (!isa_and_present<RecordType, InjectedClassNameType>(TT))
@@ -1943,12 +1941,33 @@ RecordDecl *Type::getAsRecordDecl() const {
   return cast<RecordDecl>(TT->getOriginalDecl())->getDefinitionOrSelf();
 }
 
+RecordDecl *Type::castAsRecordDecl() const {
+  const auto *TT = cast<TagType>(CanonicalType);
+  return cast<RecordDecl>(TT->getOriginalDecl())->getDefinitionOrSelf();
+}
+
+EnumDecl *Type::getAsEnumDecl() const {
+  if (const auto *TT = dyn_cast<EnumType>(CanonicalType))
+    return TT->getOriginalDecl()->getDefinitionOrSelf();
+  return nullptr;
+}
+
+EnumDecl *Type::castAsEnumDecl() const {
+  return cast<EnumType>(CanonicalType)
+      ->getOriginalDecl()
+      ->getDefinitionOrSelf();
+}
+
 TagDecl *Type::getAsTagDecl() const {
   if (const auto *TT = dyn_cast<TagType>(CanonicalType))
     return TT->getOriginalDecl()->getDefinitionOrSelf();
   return nullptr;
 }
 
+TagDecl *Type::castAsTagDecl() const {
+  return cast<TagType>(CanonicalType)->getOriginalDecl()->getDefinitionOrSelf();
+}
+
 const TemplateSpecializationType *
 Type::getAsNonAliasTemplateSpecializationType() const {
   const auto *TST = getAs<TemplateSpecializationType>();
@@ -2242,10 +2261,9 @@ bool Type::isSignedIntegerType() const {
   if (const auto *BT = dyn_cast<BuiltinType>(CanonicalType))
     return BT->isSignedInteger();
 
-  if (const EnumType *ET = dyn_cast<EnumType>(CanonicalType)) {
+  if (const auto *ED = getAsEnumDecl()) {
     // Incomplete enum types are not treated as integer types.
     // FIXME: In C++, enum types are never integer types.
-    const auto *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
     if (!ED->isComplete() || ED->isScoped())
       return false;
     return ED->getIntegerType()->isSignedIntegerType();
@@ -2263,8 +2281,7 @@ bool Type::isSignedIntegerOrEnumerationType() const {
   if (const auto *BT = dyn_cast<BuiltinType>(CanonicalType))
     return BT->isSignedInteger();
 
-  if (const auto *ET = dyn_cast<EnumType>(CanonicalType)) {
-    const auto *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *ED = getAsEnumDecl()) {
     if (!ED->isComplete())
       return false;
     return ED->getIntegerType()->isSignedIntegerType();
@@ -2292,10 +2309,9 @@ bool Type::isUnsignedIntegerType() const {
   if (const auto *BT = dyn_cast<BuiltinType>(CanonicalType))
     return BT->isUnsignedInteger();
 
-  if (const auto *ET = dyn_cast<EnumType>(CanonicalType)) {
+  if (const auto *ED = getAsEnumDecl()) {
     // Incomplete enum types are not treated as integer types.
     // FIXME: In C++, enum types are never integer types.
-    const auto *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
     if (!ED->isComplete() || ED->isScoped())
       return false;
     return ED->getIntegerType()->isUnsignedIntegerType();
@@ -2313,8 +2329,7 @@ bool Type::isUnsignedIntegerOrEnumerationType() const {
   if (const auto *BT = dyn_cast<BuiltinType>(CanonicalType))
     return BT->isUnsignedInteger();
 
-  if (const auto *ET = dyn_cast<EnumType>(CanonicalType)) {
-    const auto *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *ED = getAsEnumDecl()) {
     if (!ED->isComplete())
       return false;
     return ED->getIntegerType()->isUnsignedIntegerType();
@@ -2394,10 +2409,8 @@ bool Type::isArithmeticType() const {
 bool Type::hasBooleanRepresentation() const {
   if (const auto *VT = dyn_cast<VectorType>(CanonicalType))
     return VT->getElementType()->isBooleanType();
-  if (const auto *ET = dyn_cast<EnumType>(CanonicalType)) {
-    const auto *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *ED = getAsEnumDecl())
     return ED->isComplete() && ED->getIntegerType()->isBooleanType();
-  }
   if (const auto *IT = dyn_cast<BitIntType>(CanonicalType))
     return IT->getNumBits() == 1;
   return isBooleanType();
@@ -2428,10 +2441,7 @@ Type::ScalarTypeKind Type::getScalarTypeKind() const {
   } else if (isa<MemberPointerType>(T)) {
     return STK_MemberPointer;
   } else if (isa<EnumType>(T)) {
-    assert(cast<EnumType>(T)
-               ->getOriginalDecl()
-               ->getDefinitionOrSelf()
-               ->isComplete());
+    assert(T->castAsEnumDecl()->isComplete());
     return STK_Integral;
   } else if (const auto *CT = dyn_cast<ComplexType>(T)) {
     if (CT->getElementType()->isRealFloatingType())
@@ -2490,8 +2500,7 @@ bool Type::isIncompleteType(NamedDecl **Def) const {
     // be completed.
     return isVoidType();
   case Enum: {
-    EnumDecl *EnumD =
-        cast<EnumType>(CanonicalType)->getOriginalDecl()->getDefinitionOrSelf();
+    auto *EnumD = castAsEnumDecl();
     if (Def)
       *Def = EnumD;
     return !EnumD->isComplete();
@@ -2499,17 +2508,13 @@ bool Type::isIncompleteType(NamedDecl **Def) const {
   case Record: {
     // A tagged type (struct/union/enum/class) is incomplete if the decl is a
     // forward declaration, but not a full definition (C99 6.2.5p22).
-    RecordDecl *Rec = cast<RecordType>(CanonicalType)
-                          ->getOriginalDecl()
-                          ->getDefinitionOrSelf();
+    auto *Rec = castAsRecordDecl();
     if (Def)
       *Def = Rec;
     return !Rec->isCompleteDefinition();
   }
   case InjectedClassName: {
-    CXXRecordDecl *Rec = cast<InjectedClassNameType>(CanonicalType)
-                             ->getOriginalDecl()
-                             ->getDefinitionOrSelf();
+    auto *Rec = castAsCXXRecordDecl();
     if (!Rec->isBeingDefined())
       return false;
     if (Def)

diff  --git a/clang/lib/AST/VTTBuilder.cpp b/clang/lib/AST/VTTBuilder.cpp
index 85101aee97e66..89b58b557ddca 100644
--- a/clang/lib/AST/VTTBuilder.cpp
+++ b/clang/lib/AST/VTTBuilder.cpp
@@ -63,11 +63,7 @@ void VTTBuilder::LayoutSecondaryVTTs(BaseSubobject Base) {
     if (I.isVirtual())
         continue;
 
-    const auto *BaseDecl =
-        cast<CXXRecordDecl>(
-            I.getType()->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
-
+    const auto *BaseDecl = I.getType()->castAsCXXRecordDecl();
     const ASTRecordLayout &Layout = Ctx.getASTRecordLayout(RD);
     CharUnits BaseOffset = Base.getBaseOffset() +
       Layout.getBaseClassOffset(BaseDecl);
@@ -91,10 +87,7 @@ VTTBuilder::LayoutSecondaryVirtualPointers(BaseSubobject Base,
     return;
 
   for (const auto &I : RD->bases()) {
-    const auto *BaseDecl =
-        cast<CXXRecordDecl>(
-            I.getType()->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
+    const auto *BaseDecl = I.getType()->castAsCXXRecordDecl();
 
     // Itanium C++ ABI 2.6.2:
     //   Secondary virtual pointers are present for all bases with either
@@ -157,10 +150,7 @@ VTTBuilder::LayoutSecondaryVirtualPointers(BaseSubobject Base,
 void VTTBuilder::LayoutVirtualVTTs(const CXXRecordDecl *RD,
                                    VisitedVirtualBasesSetTy &VBases) {
   for (const auto &I : RD->bases()) {
-    const auto *BaseDecl =
-        cast<CXXRecordDecl>(
-            I.getType()->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
+    const auto *BaseDecl = I.getType()->castAsCXXRecordDecl();
 
     // Check if this is a virtual base.
     if (I.isVirtual()) {

diff  --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
index 8a15e5f96aea2..25859885296fa 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
@@ -287,10 +287,7 @@ void CIRGenFunction::emitDelegateCallArg(CallArgList &args,
   // Deactivate the cleanup for the callee-destructed param that was pushed.
   assert(!cir::MissingFeatures::thunks());
   if (type->isRecordType() &&
-      type->castAs<RecordType>()
-          ->getOriginalDecl()
-          ->getDefinitionOrSelf()
-          ->isParamDestroyedInCallee() &&
+      type->castAsRecordDecl()->isParamDestroyedInCallee() &&
       param->needsDestruction(getContext())) {
     cgm.errorNYI(param->getSourceRange(),
                  "emitDelegateCallArg: callee-destructed param");
@@ -690,10 +687,8 @@ void CIRGenFunction::emitCallArg(CallArgList &args, const clang::Expr *e,
   // In the Microsoft C++ ABI, aggregate arguments are destructed by the callee.
   // However, we still have to push an EH-only cleanup in case we unwind before
   // we make it to the call.
-  if (argType->isRecordType() && argType->castAs<RecordType>()
-                                     ->getOriginalDecl()
-                                     ->getDefinitionOrSelf()
-                                     ->isParamDestroyedInCallee()) {
+  if (argType->isRecordType() &&
+      argType->castAsRecordDecl()->isParamDestroyedInCallee()) {
     assert(!cir::MissingFeatures::msabi());
     cgm.errorNYI(e->getSourceRange(), "emitCallArg: msabi is NYI");
   }

diff  --git a/clang/lib/CIR/CodeGen/CIRGenClass.cpp b/clang/lib/CIR/CodeGen/CIRGenClass.cpp
index 3e5dc22426d8e..722027aa5631d 100644
--- a/clang/lib/CIR/CodeGen/CIRGenClass.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenClass.cpp
@@ -120,9 +120,7 @@ static void emitMemberInitializer(CIRGenFunction &cgf,
 
 static bool isInitializerOfDynamicClass(const CXXCtorInitializer *baseInit) {
   const Type *baseType = baseInit->getBaseClass();
-  const auto *baseClassDecl =
-      cast<CXXRecordDecl>(baseType->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
+  const auto *baseClassDecl = baseType->castAsCXXRecordDecl();
   return baseClassDecl->isDynamicClass();
 }
 
@@ -161,9 +159,7 @@ void CIRGenFunction::emitBaseInitializer(mlir::Location loc,
   Address thisPtr = loadCXXThisAddress();
 
   const Type *baseType = baseInit->getBaseClass();
-  const auto *baseClassDecl =
-      cast<CXXRecordDecl>(baseType->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
+  const auto *baseClassDecl = baseType->castAsCXXRecordDecl();
 
   bool isBaseVirtual = baseInit->isBaseVirtual();
 
@@ -568,9 +564,7 @@ void CIRGenFunction::emitImplicitAssignmentOperatorBody(FunctionArgList &args) {
 
 void CIRGenFunction::destroyCXXObject(CIRGenFunction &cgf, Address addr,
                                       QualType type) {
-  const RecordType *rtype = type->castAs<RecordType>();
-  const CXXRecordDecl *record =
-      cast<CXXRecordDecl>(rtype->getOriginalDecl())->getDefinitionOrSelf();
+  const auto *record = type->castAsCXXRecordDecl();
   const CXXDestructorDecl *dtor = record->getDestructor();
   // TODO(cir): Unlike traditional codegen, CIRGen should actually emit trivial
   // dtors which shall be removed on later CIR passes. However, only remove this

diff  --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
index d679e1c1ff732..1afac6dd52c2d 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
@@ -1092,11 +1092,7 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) {
 
   case CK_UncheckedDerivedToBase:
   case CK_DerivedToBase: {
-    const auto *derivedClassTy =
-        e->getSubExpr()->getType()->castAs<clang::RecordType>();
-    auto *derivedClassDecl =
-        cast<CXXRecordDecl>(derivedClassTy->getOriginalDecl())
-            ->getDefinitionOrSelf();
+    auto *derivedClassDecl = e->getSubExpr()->getType()->castAsCXXRecordDecl();
 
     LValue lv = emitLValue(e->getSubExpr());
     Address thisAddr = lv.getAddress();

diff  --git a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp
index 2a3c98c458256..bab9ac73d7e65 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp
@@ -409,10 +409,7 @@ void AggExprEmitter::visitCXXParenListOrInitListExpr(
   // the disadvantage is that the generated code is more difficult for
   // the optimizer, especially with bitfields.
   unsigned numInitElements = args.size();
-  RecordDecl *record = e->getType()
-                           ->castAs<RecordType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf();
+  auto *record = e->getType()->castAsRecordDecl();
 
   // We'll need to enter cleanup scopes in case any of the element
   // initializers throws an exception.

diff  --git a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
index 23132eae3214e..262d2548d5c39 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
@@ -680,9 +680,7 @@ mlir::Attribute ConstantEmitter::tryEmitPrivateForVarInit(const VarDecl &d) {
         // assignments and whatnots). Since this is for globals shouldn't
         // be a problem for the near future.
         if (cd->isTrivial() && cd->isDefaultConstructor()) {
-          const auto *cxxrd =
-              cast<CXXRecordDecl>(ty->getAs<RecordType>()->getOriginalDecl())
-                  ->getDefinitionOrSelf();
+          const auto *cxxrd = ty->castAsCXXRecordDecl();
           if (cxxrd->getNumBases() != 0) {
             // There may not be anything additional to do here, but this will
             // force us to pause and test this path when it is supported.

diff  --git a/clang/lib/CIR/CodeGen/CIRGenModule.cpp b/clang/lib/CIR/CodeGen/CIRGenModule.cpp
index 08b40e08c51de..c7f548498c5cb 100644
--- a/clang/lib/CIR/CodeGen/CIRGenModule.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenModule.cpp
@@ -2180,10 +2180,7 @@ CharUnits CIRGenModule::computeNonVirtualBaseClassOffset(
     // Get the layout.
     const ASTRecordLayout &layout = astContext.getASTRecordLayout(rd);
 
-    const auto *baseDecl =
-        cast<CXXRecordDecl>(
-            base->getType()->castAs<clang::RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
+    const auto *baseDecl = base->getType()->castAsCXXRecordDecl();
 
     // Add the offset.
     offset += layout.getBaseClassOffset(baseDecl);

diff  --git a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
index 44b631934ffaf..bb24933a22ed7 100644
--- a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
@@ -246,10 +246,7 @@ mlir::Type CIRGenTypes::convertRecordDeclType(const clang::RecordDecl *rd) {
     for (const auto &base : cxxRecordDecl->bases()) {
       if (base.isVirtual())
         continue;
-      convertRecordDeclType(base.getType()
-                                ->castAs<RecordType>()
-                                ->getOriginalDecl()
-                                ->getDefinitionOrSelf());
+      convertRecordDeclType(base.getType()->castAsRecordDecl());
     }
   }
 
@@ -465,8 +462,7 @@ mlir::Type CIRGenTypes::convertType(QualType type) {
   }
 
   case Type::Enum: {
-    const EnumDecl *ed =
-        cast<EnumType>(ty)->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *ed = ty->castAsEnumDecl();
     if (auto integerType = ed->getIntegerType(); !integerType.isNull())
       return convertType(integerType);
     // Return a placeholder 'i32' type.  This can be changed later when the

diff  --git a/clang/lib/CodeGen/ABIInfoImpl.cpp b/clang/lib/CodeGen/ABIInfoImpl.cpp
index 989a3a4387cf4..72d359e4862b9 100644
--- a/clang/lib/CodeGen/ABIInfoImpl.cpp
+++ b/clang/lib/CodeGen/ABIInfoImpl.cpp
@@ -28,8 +28,8 @@ ABIArgInfo DefaultABIInfo::classifyArgumentType(QualType Ty) const {
   }
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   ASTContext &Context = getContext();
   if (const auto *EIT = Ty->getAs<BitIntType>())
@@ -52,8 +52,8 @@ ABIArgInfo DefaultABIInfo::classifyReturnType(QualType RetTy) const {
     return getNaturalAlignIndirect(RetTy, getDataLayout().getAllocaAddrSpace());
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-    RetTy = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = RetTy->getAsEnumDecl())
+    RetTy = ED->getIntegerType();
 
   if (const auto *EIT = RetTy->getAs<BitIntType>())
     if (EIT->getNumBits() >

diff  --git a/clang/lib/CodeGen/CGCUDANV.cpp b/clang/lib/CodeGen/CGCUDANV.cpp
index c7f4bf8a21354..5090a0559eab2 100644
--- a/clang/lib/CodeGen/CGCUDANV.cpp
+++ b/clang/lib/CodeGen/CGCUDANV.cpp
@@ -1131,8 +1131,7 @@ void CGNVCUDARuntime::handleVarRegistration(const VarDecl *D,
     // Builtin surfaces and textures and their template arguments are
     // also registered with CUDA runtime.
     const auto *TD = cast<ClassTemplateSpecializationDecl>(
-                         D->getType()->castAs<RecordType>()->getOriginalDecl())
-                         ->getDefinitionOrSelf();
+        D->getType()->castAsCXXRecordDecl());
     const TemplateArgumentList &Args = TD->getTemplateArgs();
     if (TD->hasAttr<CUDADeviceBuiltinSurfaceTypeAttr>()) {
       assert(Args.size() == 2 &&

diff  --git a/clang/lib/CodeGen/CGCXX.cpp b/clang/lib/CodeGen/CGCXX.cpp
index f9aff893eb0f0..59aeff6804b61 100644
--- a/clang/lib/CodeGen/CGCXX.cpp
+++ b/clang/lib/CodeGen/CGCXX.cpp
@@ -83,9 +83,7 @@ bool CodeGenModule::TryEmitBaseDestructorAsAlias(const CXXDestructorDecl *D) {
     if (I.isVirtual()) continue;
 
     // Skip base classes with trivial destructors.
-    const auto *Base = cast<CXXRecordDecl>(
-                           I.getType()->castAs<RecordType>()->getOriginalDecl())
-                           ->getDefinitionOrSelf();
+    const auto *Base = I.getType()->castAsCXXRecordDecl();
     if (Base->hasTrivialDestructor()) continue;
 
     // If we've already found a base class with a non-trivial
@@ -281,16 +279,8 @@ static CGCallee BuildAppleKextVirtualCall(CodeGenFunction &CGF,
 CGCallee CodeGenFunction::BuildAppleKextVirtualCall(const CXXMethodDecl *MD,
                                                     NestedNameSpecifier Qual,
                                                     llvm::Type *Ty) {
-  assert(Qual.getKind() == NestedNameSpecifier::Kind::Type &&
-         "BuildAppleKextVirtualCall - bad Qual kind");
-
-  const Type *QTy = Qual.getAsType();
-  QualType T = QualType(QTy, 0);
-  const RecordType *RT = T->getAs<RecordType>();
-  assert(RT && "BuildAppleKextVirtualCall - Qual type must be record");
-  const auto *RD =
-      cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
-
+  const CXXRecordDecl *RD = Qual.getAsRecordDecl();
+  assert(RD && "BuildAppleKextVirtualCall - Qual must be record");
   if (const auto *DD = dyn_cast<CXXDestructorDecl>(MD))
     return BuildAppleKextVirtualDestructorCall(DD, Dtor_Complete, RD);
 

diff  --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 89eb6d2b393a3..f61d3d987b3c9 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -2873,10 +2873,7 @@ void CodeGenModule::ConstructAttributeList(StringRef Name,
           // whose destruction / clean-up is carried out within the callee
           // (e.g., Obj-C ARC-managed structs, MSVC callee-destroyed objects).
           if (!ParamType.isDestructedType() || !ParamType->isRecordType() ||
-              ParamType->castAs<RecordType>()
-                  ->getOriginalDecl()
-                  ->getDefinitionOrSelf()
-                  ->isParamDestroyedInCallee())
+              ParamType->castAsRecordDecl()->isParamDestroyedInCallee())
             Attrs.addAttribute(llvm::Attribute::DeadOnReturn);
         }
       }
@@ -4307,10 +4304,7 @@ void CodeGenFunction::EmitDelegateCallArg(CallArgList &args,
 
   // Deactivate the cleanup for the callee-destructed param that was pushed.
   if (type->isRecordType() && !CurFuncIsThunk &&
-      type->castAs<RecordType>()
-          ->getOriginalDecl()
-          ->getDefinitionOrSelf()
-          ->isParamDestroyedInCallee() &&
+      type->castAsRecordDecl()->isParamDestroyedInCallee() &&
       param->needsDestruction(getContext())) {
     EHScopeStack::stable_iterator cleanup =
         CalleeDestructedParamCleanups.lookup(cast<ParmVarDecl>(param));
@@ -4903,10 +4897,8 @@ void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
   // In the Microsoft C++ ABI, aggregate arguments are destructed by the callee.
   // However, we still have to push an EH-only cleanup in case we unwind before
   // we make it to the call.
-  if (type->isRecordType() && type->castAs<RecordType>()
-                                  ->getOriginalDecl()
-                                  ->getDefinitionOrSelf()
-                                  ->isParamDestroyedInCallee()) {
+  if (type->isRecordType() &&
+      type->castAsRecordDecl()->isParamDestroyedInCallee()) {
     // If we're using inalloca, use the argument memory.  Otherwise, use a
     // temporary.
     AggValueSlot Slot = args.isUsingInAlloca()

diff  --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp
index 10e4543a6ab20..bae55aa1e1928 100644
--- a/clang/lib/CodeGen/CGClass.cpp
+++ b/clang/lib/CodeGen/CGClass.cpp
@@ -180,11 +180,7 @@ CharUnits CodeGenModule::computeNonVirtualBaseClassOffset(
     // Get the layout.
     const ASTRecordLayout &Layout = Context.getASTRecordLayout(RD);
 
-    const auto *BaseDecl =
-        cast<CXXRecordDecl>(
-            Base->getType()->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
-
+    const auto *BaseDecl = Base->getType()->castAsCXXRecordDecl();
     // Add the offset.
     Offset += Layout.getBaseClassOffset(BaseDecl);
 
@@ -302,9 +298,7 @@ Address CodeGenFunction::GetAddressOfBaseClass(
   // *start* with a step down to the correct virtual base subobject,
   // and hence will not require any further steps.
   if ((*Start)->isVirtual()) {
-    VBase = cast<CXXRecordDecl>(
-                (*Start)->getType()->castAs<RecordType>()->getOriginalDecl())
-                ->getDefinitionOrSelf();
+    VBase = (*Start)->getType()->castAsCXXRecordDecl();
     ++Start;
   }
 
@@ -559,10 +553,7 @@ static void EmitBaseInitializer(CodeGenFunction &CGF,
 
   Address ThisPtr = CGF.LoadCXXThisAddress();
 
-  const Type *BaseType = BaseInit->getBaseClass();
-  const auto *BaseClassDecl =
-      cast<CXXRecordDecl>(BaseType->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
+  const auto *BaseClassDecl = BaseInit->getBaseClass()->castAsCXXRecordDecl();
 
   bool isBaseVirtual = BaseInit->isBaseVirtual();
 
@@ -1267,10 +1258,7 @@ namespace {
 
 static bool isInitializerOfDynamicClass(const CXXCtorInitializer *BaseInit) {
   const Type *BaseType = BaseInit->getBaseClass();
-  const auto *BaseClassDecl =
-      cast<CXXRecordDecl>(BaseType->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
-  return BaseClassDecl->isDynamicClass();
+  return BaseType->castAsCXXRecordDecl()->isDynamicClass();
 }
 
 /// EmitCtorPrologue - This routine generates necessary code to initialize
@@ -1377,10 +1365,7 @@ HasTrivialDestructorBody(ASTContext &Context,
     if (I.isVirtual())
       continue;
 
-    const CXXRecordDecl *NonVirtualBase =
-        cast<CXXRecordDecl>(
-            I.getType()->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
+    const auto *NonVirtualBase = I.getType()->castAsCXXRecordDecl();
     if (!HasTrivialDestructorBody(Context, NonVirtualBase,
                                   MostDerivedClassDecl))
       return false;
@@ -1389,10 +1374,7 @@ HasTrivialDestructorBody(ASTContext &Context,
   if (BaseClassDecl == MostDerivedClassDecl) {
     // Check virtual bases.
     for (const auto &I : BaseClassDecl->vbases()) {
-      const auto *VirtualBase =
-          cast<CXXRecordDecl>(
-              I.getType()->castAs<RecordType>()->getOriginalDecl())
-              ->getDefinitionOrSelf();
+      const auto *VirtualBase = I.getType()->castAsCXXRecordDecl();
       if (!HasTrivialDestructorBody(Context, VirtualBase,
                                     MostDerivedClassDecl))
         return false;
@@ -1904,11 +1886,7 @@ void CodeGenFunction::EnterDtorCleanups(const CXXDestructorDecl *DD,
     // We push them in the forward order so that they'll be popped in
     // the reverse order.
     for (const auto &Base : ClassDecl->vbases()) {
-      auto *BaseClassDecl =
-          cast<CXXRecordDecl>(
-              Base.getType()->castAs<RecordType>()->getOriginalDecl())
-              ->getDefinitionOrSelf();
-
+      auto *BaseClassDecl = Base.getType()->castAsCXXRecordDecl();
       if (BaseClassDecl->hasTrivialDestructor()) {
         // Under SanitizeMemoryUseAfterDtor, poison the trivial base class
         // memory. For non-trival base classes the same is done in the class
@@ -2127,10 +2105,7 @@ void CodeGenFunction::EmitCXXAggrConstructorCall(const CXXConstructorDecl *ctor,
 void CodeGenFunction::destroyCXXObject(CodeGenFunction &CGF,
                                        Address addr,
                                        QualType type) {
-  const RecordType *rtype = type->castAs<RecordType>();
-  const auto *record =
-      cast<CXXRecordDecl>(rtype->getOriginalDecl())->getDefinitionOrSelf();
-  const CXXDestructorDecl *dtor = record->getDestructor();
+  const CXXDestructorDecl *dtor = type->castAsCXXRecordDecl()->getDestructor();
   assert(!dtor->isTrivial());
   CGF.EmitCXXDestructorCall(dtor, Dtor_Complete, /*for vbase*/ false,
                             /*Delegating=*/false, addr, type);
@@ -2649,10 +2624,7 @@ void CodeGenFunction::getVTablePointers(BaseSubobject Base,
 
   // Traverse bases.
   for (const auto &I : RD->bases()) {
-    auto *BaseDecl = cast<CXXRecordDecl>(
-                         I.getType()->castAs<RecordType>()->getOriginalDecl())
-                         ->getDefinitionOrSelf();
-
+    auto *BaseDecl = I.getType()->castAsCXXRecordDecl();
     // Ignore classes without a vtable.
     if (!BaseDecl->isDynamicClass())
       continue;

diff  --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp
index c44fea3e6383d..560749f8ab308 100644
--- a/clang/lib/CodeGen/CGDebugInfo.cpp
+++ b/clang/lib/CodeGen/CGDebugInfo.cpp
@@ -5921,8 +5921,7 @@ void CGDebugInfo::EmitGlobalVariable(llvm::GlobalVariable *Var,
   // variable for each member of the anonymous union so that it's possible
   // to find the name of any field in the union.
   if (T->isUnionType() && DeclName.empty()) {
-    const RecordDecl *RD =
-        T->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *RD = T->castAsRecordDecl();
     assert(RD->isAnonymousStructOrUnion() &&
            "unnamed non-anonymous struct or union?");
     GVE = CollectAnonRecordDecls(RD, Unit, LineNo, LinkageName, Var, DContext);

diff  --git a/clang/lib/CodeGen/CGDecl.cpp b/clang/lib/CodeGen/CGDecl.cpp
index 8693f3c71ea80..29193e0c541b9 100644
--- a/clang/lib/CodeGen/CGDecl.cpp
+++ b/clang/lib/CodeGen/CGDecl.cpp
@@ -2727,10 +2727,7 @@ void CodeGenFunction::EmitParmDecl(const VarDecl &D, ParamValue Arg,
     // Don't push a cleanup in a thunk for a method that will also emit a
     // cleanup.
     if (Ty->isRecordType() && !CurFuncIsThunk &&
-        Ty->castAs<RecordType>()
-            ->getOriginalDecl()
-            ->getDefinitionOrSelf()
-            ->isParamDestroyedInCallee()) {
+        Ty->castAsRecordDecl()->isParamDestroyedInCallee()) {
       if (QualType::DestructionKind DtorKind =
               D.needsDestruction(getContext())) {
         assert((DtorKind == QualType::DK_cxx_destructor ||

diff  --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index e8942aceb9b09..eeaf68dfd0521 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -1984,11 +1984,9 @@ llvm::Value *CodeGenFunction::EmitLoadOfScalar(LValue lvalue,
 static bool getRangeForType(CodeGenFunction &CGF, QualType Ty,
                             llvm::APInt &Min, llvm::APInt &End,
                             bool StrictEnums, bool IsBool) {
-  const EnumType *ET = Ty->getAs<EnumType>();
-  const EnumDecl *ED =
-      ET ? ET->getOriginalDecl()->getDefinitionOrSelf() : nullptr;
+  const auto *ED = Ty->getAsEnumDecl();
   bool IsRegularCPlusPlusEnum =
-      CGF.getLangOpts().CPlusPlus && StrictEnums && ET && !ED->isFixed();
+      CGF.getLangOpts().CPlusPlus && StrictEnums && ED && !ED->isFixed();
   if (!IsBool && !IsRegularCPlusPlusEnum)
     return false;
 
@@ -5723,12 +5721,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
 
   case CK_UncheckedDerivedToBase:
   case CK_DerivedToBase: {
-    const auto *DerivedClassTy =
-        E->getSubExpr()->getType()->castAs<RecordType>();
-    auto *DerivedClassDecl =
-        cast<CXXRecordDecl>(DerivedClassTy->getOriginalDecl())
-            ->getDefinitionOrSelf();
-
+    auto *DerivedClassDecl = E->getSubExpr()->getType()->castAsCXXRecordDecl();
     LValue LV = EmitLValue(E->getSubExpr());
     Address This = LV.getAddress();
 
@@ -5746,11 +5739,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
   case CK_ToUnion:
     return EmitAggExprToLValue(E);
   case CK_BaseToDerived: {
-    const auto *DerivedClassTy = E->getType()->castAs<RecordType>();
-    auto *DerivedClassDecl =
-        cast<CXXRecordDecl>(DerivedClassTy->getOriginalDecl())
-            ->getDefinitionOrSelf();
-
+    auto *DerivedClassDecl = E->getType()->castAsCXXRecordDecl();
     LValue LV = EmitLValue(E->getSubExpr());
 
     // Perform the base-to-derived conversion

diff  --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index ed90e987b180d..50575362d304b 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -424,10 +424,7 @@ AggExprEmitter::VisitCXXStdInitializerListExpr(CXXStdInitializerListExpr *E) {
       Ctx.getAsConstantArrayType(E->getSubExpr()->getType());
   assert(ArrayType && "std::initializer_list constructed from non-array");
 
-  RecordDecl *Record = E->getType()
-                           ->castAs<RecordType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf();
+  auto *Record = E->getType()->castAsRecordDecl();
   RecordDecl::field_iterator Field = Record->field_begin();
   assert(Field != Record->field_end() &&
          Ctx.hasSameType(Field->getType()->getPointeeType(),
@@ -1809,10 +1806,7 @@ void AggExprEmitter::VisitCXXParenListOrInitListExpr(
   // the disadvantage is that the generated code is more difficult for
   // the optimizer, especially with bitfields.
   unsigned NumInitElements = InitExprs.size();
-  RecordDecl *record = ExprToVisit->getType()
-                           ->castAs<RecordType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf();
+  RecordDecl *record = ExprToVisit->getType()->castAsRecordDecl();
 
   // We'll need to enter cleanup scopes in case any of the element
   // initializers throws an exception.

diff  --git a/clang/lib/CodeGen/CGExprCXX.cpp b/clang/lib/CodeGen/CGExprCXX.cpp
index b6868c1fb1f42..810a2ab8c1fcd 100644
--- a/clang/lib/CodeGen/CGExprCXX.cpp
+++ b/clang/lib/CodeGen/CGExprCXX.cpp
@@ -180,8 +180,7 @@ static CXXRecordDecl *getCXXRecord(const Expr *E) {
   QualType T = E->getType();
   if (const PointerType *PTy = T->getAs<PointerType>())
     T = PTy->getPointeeType();
-  const RecordType *Ty = T->castAs<RecordType>();
-  return cast<CXXRecordDecl>(Ty->getOriginalDecl())->getDefinitionOrSelf();
+  return T->castAsCXXRecordDecl();
 }
 
 // Note: This function also emit constructor calls to support a MSVC
@@ -1687,11 +1686,8 @@ llvm::Value *CodeGenFunction::EmitCXXNewExpr(const CXXNewExpr *E) {
       QualType AlignValT = sizeType;
       if (allocatorType->getNumParams() > IndexOfAlignArg) {
         AlignValT = allocatorType->getParamType(IndexOfAlignArg);
-        assert(getContext().hasSameUnqualifiedType(AlignValT->castAs<EnumType>()
-                                                       ->getOriginalDecl()
-                                                       ->getDefinitionOrSelf()
-                                                       ->getIntegerType(),
-                                                   sizeType) &&
+        assert(getContext().hasSameUnqualifiedType(
+                   AlignValT->castAsEnumDecl()->getIntegerType(), sizeType) &&
                "wrong type for alignment parameter");
         ++ParamsToSkip;
       } else {

diff  --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index c4679bd734de9..b44dd9ecc717e 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -714,10 +714,7 @@ static bool EmitDesignatedInitUpdater(ConstantEmitter &Emitter,
 }
 
 bool ConstStructBuilder::Build(const InitListExpr *ILE, bool AllowOverwrite) {
-  RecordDecl *RD = ILE->getType()
-                       ->castAs<RecordType>()
-                       ->getOriginalDecl()
-                       ->getDefinitionOrSelf();
+  auto *RD = ILE->getType()->castAsRecordDecl();
   const ASTRecordLayout &Layout = CGM.getContext().getASTRecordLayout(RD);
 
   unsigned FieldNo = -1;
@@ -981,8 +978,7 @@ bool ConstStructBuilder::DoZeroInitPadding(const ASTRecordLayout &Layout,
 
 llvm::Constant *ConstStructBuilder::Finalize(QualType Type) {
   Type = Type.getNonReferenceType();
-  RecordDecl *RD =
-      Type->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+  auto *RD = Type->castAsRecordDecl();
   llvm::Type *ValTy = CGM.getTypes().ConvertType(Type);
   return Builder.build(ValTy, RD->hasFlexibleArrayMember());
 }
@@ -1005,8 +1001,7 @@ llvm::Constant *ConstStructBuilder::BuildStruct(ConstantEmitter &Emitter,
   ConstantAggregateBuilder Const(Emitter.CGM);
   ConstStructBuilder Builder(Emitter, Const, CharUnits::Zero());
 
-  const RecordDecl *RD =
-      ValTy->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *RD = ValTy->castAsRecordDecl();
   const CXXRecordDecl *CD = dyn_cast<CXXRecordDecl>(RD);
   if (!Builder.Build(Val, RD, false, CD, CharUnits::Zero()))
     return nullptr;
@@ -2645,11 +2640,7 @@ static llvm::Constant *EmitNullConstant(CodeGenModule &CGM,
         continue;
       }
 
-      const CXXRecordDecl *base =
-          cast<CXXRecordDecl>(
-              I.getType()->castAs<RecordType>()->getOriginalDecl())
-              ->getDefinitionOrSelf();
-
+      const auto *base = I.getType()->castAsCXXRecordDecl();
       // Ignore empty bases.
       if (isEmptyRecordForLayout(CGM.getContext(), I.getType()) ||
           CGM.getContext()
@@ -2687,11 +2678,7 @@ static llvm::Constant *EmitNullConstant(CodeGenModule &CGM,
   // Fill in the virtual bases, if we're working with the complete object.
   if (CXXR && asCompleteObject) {
     for (const auto &I : CXXR->vbases()) {
-      const auto *base =
-          cast<CXXRecordDecl>(
-              I.getType()->castAs<RecordType>()->getOriginalDecl())
-              ->getDefinitionOrSelf();
-
+      const auto *base = I.getType()->castAsCXXRecordDecl();
       // Ignore empty bases.
       if (isEmptyRecordForLayout(CGM.getContext(), I.getType()))
         continue;

diff  --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 1e23f45546748..43d295599c4c8 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -3515,9 +3515,7 @@ Value *ScalarExprEmitter::VisitOffsetOfExpr(OffsetOfExpr *E) {
 
     case OffsetOfNode::Field: {
       FieldDecl *MemberDecl = ON.getField();
-      RecordDecl *RD = CurrentType->castAs<RecordType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf();
+      auto *RD = CurrentType->castAsRecordDecl();
       const ASTRecordLayout &RL = CGF.getContext().getASTRecordLayout(RD);
 
       // Compute the index of the field in its parent.
@@ -3557,9 +3555,7 @@ Value *ScalarExprEmitter::VisitOffsetOfExpr(OffsetOfExpr *E) {
       CurrentType = ON.getBase()->getType();
 
       // Compute the offset to the base.
-      auto *BaseRT = CurrentType->castAs<RecordType>();
-      auto *BaseRD =
-          cast<CXXRecordDecl>(BaseRT->getOriginalDecl())->getDefinitionOrSelf();
+      auto *BaseRD = CurrentType->castAsCXXRecordDecl();
       CharUnits OffsetInt = RL.getBaseClassOffset(BaseRD);
       Offset = llvm::ConstantInt::get(ResultType, OffsetInt.getQuantity());
       break;

diff  --git a/clang/lib/CodeGen/CGNonTrivialStruct.cpp b/clang/lib/CodeGen/CGNonTrivialStruct.cpp
index 1b941fff8b644..b78d89fd1e348 100644
--- a/clang/lib/CodeGen/CGNonTrivialStruct.cpp
+++ b/clang/lib/CodeGen/CGNonTrivialStruct.cpp
@@ -39,8 +39,7 @@ template <class Derived> struct StructVisitor {
 
   template <class... Ts>
   void visitStructFields(QualType QT, CharUnits CurStructOffset, Ts... Args) {
-    const RecordDecl *RD =
-        QT->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *RD = QT->castAsRecordDecl();
 
     // Iterate over the fields of the struct.
     for (const FieldDecl *FD : RD->fields()) {

diff  --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index f98339d472fa9..b66608319bb51 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -3006,10 +3006,10 @@ emitProxyTaskFunction(CodeGenModule &CGM, SourceLocation Loc,
       CGF.GetAddrOfLocalVar(&TaskTypeArg),
       KmpTaskTWithPrivatesPtrQTy->castAs<PointerType>());
   const auto *KmpTaskTWithPrivatesQTyRD =
-      cast<RecordDecl>(KmpTaskTWithPrivatesQTy->getAsTagDecl());
+      KmpTaskTWithPrivatesQTy->castAsRecordDecl();
   LValue Base =
       CGF.EmitLValueForField(TDBase, *KmpTaskTWithPrivatesQTyRD->field_begin());
-  const auto *KmpTaskTQTyRD = cast<RecordDecl>(KmpTaskTQTy->getAsTagDecl());
+  const auto *KmpTaskTQTyRD = KmpTaskTQTy->castAsRecordDecl();
   auto PartIdFI = std::next(KmpTaskTQTyRD->field_begin(), KmpTaskTPartId);
   LValue PartIdLVal = CGF.EmitLValueForField(Base, *PartIdFI);
   llvm::Value *PartidParam = PartIdLVal.getPointer(CGF);
@@ -3104,11 +3104,10 @@ static llvm::Value *emitDestructorsFunction(CodeGenModule &CGM,
       CGF.GetAddrOfLocalVar(&TaskTypeArg),
       KmpTaskTWithPrivatesPtrQTy->castAs<PointerType>());
   const auto *KmpTaskTWithPrivatesQTyRD =
-      cast<RecordDecl>(KmpTaskTWithPrivatesQTy->getAsTagDecl());
+      KmpTaskTWithPrivatesQTy->castAsRecordDecl();
   auto FI = std::next(KmpTaskTWithPrivatesQTyRD->field_begin());
   Base = CGF.EmitLValueForField(Base, *FI);
-  for (const auto *Field :
-       cast<RecordDecl>(FI->getType()->getAsTagDecl())->fields()) {
+  for (const auto *Field : FI->getType()->castAsRecordDecl()->fields()) {
     if (QualType::DestructionKind DtorKind =
             Field->getType().isDestructedType()) {
       LValue FieldLValue = CGF.EmitLValueForField(Base, Field);
@@ -3212,7 +3211,7 @@ emitTaskPrivateMappingFunction(CodeGenModule &CGM, SourceLocation Loc,
   LValue Base = CGF.EmitLoadOfPointerLValue(
       CGF.GetAddrOfLocalVar(&TaskPrivatesArg),
       TaskPrivatesArg.getType()->castAs<PointerType>());
-  const auto *PrivatesQTyRD = cast<RecordDecl>(PrivatesQTy->getAsTagDecl());
+  const auto *PrivatesQTyRD = PrivatesQTy->castAsRecordDecl();
   Counter = 0;
   for (const FieldDecl *Field : PrivatesQTyRD->fields()) {
     LValue FieldLVal = CGF.EmitLValueForField(Base, Field);
@@ -3259,7 +3258,7 @@ static void emitPrivatesInit(CodeGenFunction &CGF,
             CGF.ConvertTypeForMem(SharedsTy)),
         SharedsTy);
   }
-  FI = cast<RecordDecl>(FI->getType()->getAsTagDecl())->field_begin();
+  FI = FI->getType()->castAsRecordDecl()->field_begin();
   for (const PrivateDataTy &Pair : Privates) {
     // Do not initialize private locals.
     if (Pair.second.isLocalPrivate()) {
@@ -3655,7 +3654,7 @@ CGOpenMPRuntime::emitTaskInit(CodeGenFunction &CGF, SourceLocation Loc,
     }
     KmpTaskTQTy = SavedKmpTaskTQTy;
   }
-  const auto *KmpTaskTQTyRD = cast<RecordDecl>(KmpTaskTQTy->getAsTagDecl());
+  const auto *KmpTaskTQTyRD = KmpTaskTQTy->castAsRecordDecl();
   // Build particular struct kmp_task_t for the given task.
   const RecordDecl *KmpTaskTWithPrivatesQTyRD =
       createKmpTaskTWithPrivatesRecordDecl(CGM, KmpTaskTQTy, Privates);
@@ -3915,10 +3914,7 @@ CGOpenMPRuntime::emitTaskInit(CodeGenFunction &CGF, SourceLocation Loc,
   // Fill the data in the resulting kmp_task_t record.
   // Copy shareds if there are any.
   Address KmpTaskSharedsPtr = Address::invalid();
-  if (!SharedsTy->getAsStructureType()
-           ->getOriginalDecl()
-           ->getDefinitionOrSelf()
-           ->field_empty()) {
+  if (!SharedsTy->castAsRecordDecl()->field_empty()) {
     KmpTaskSharedsPtr = Address(
         CGF.EmitLoadOfScalar(
             CGF.EmitLValueForField(
@@ -3948,11 +3944,8 @@ CGOpenMPRuntime::emitTaskInit(CodeGenFunction &CGF, SourceLocation Loc,
   enum { Priority = 0, Destructors = 1 };
   // Provide pointer to function with destructors for privates.
   auto FI = std::next(KmpTaskTQTyRD->field_begin(), Data1);
-  const RecordDecl *KmpCmplrdataUD = (*FI)
-                                         ->getType()
-                                         ->getAsUnionType()
-                                         ->getOriginalDecl()
-                                         ->getDefinitionOrSelf();
+  const auto *KmpCmplrdataUD = (*FI)->getType()->castAsRecordDecl();
+  assert(KmpCmplrdataUD->isUnion());
   if (NeedsCleanup) {
     llvm::Value *DestructorFn = emitDestructorsFunction(
         CGM, Loc, KmpInt32Ty, KmpTaskTWithPrivatesPtrQTy,
@@ -4032,8 +4025,7 @@ CGOpenMPRuntime::getDepobjElements(CodeGenFunction &CGF, LValue DepobjLVal,
   ASTContext &C = CGM.getContext();
   QualType FlagsTy;
   getDependTypes(C, KmpDependInfoTy, FlagsTy);
-  RecordDecl *KmpDependInfoRD =
-      cast<RecordDecl>(KmpDependInfoTy->getAsTagDecl());
+  auto *KmpDependInfoRD = KmpDependInfoTy->castAsRecordDecl();
   QualType KmpDependInfoPtrTy = C.getPointerType(KmpDependInfoTy);
   LValue Base = CGF.EmitLoadOfPointerLValue(
       DepobjLVal.getAddress().withElementType(
@@ -4061,8 +4053,7 @@ static void emitDependData(CodeGenFunction &CGF, QualType &KmpDependInfoTy,
   ASTContext &C = CGM.getContext();
   QualType FlagsTy;
   getDependTypes(C, KmpDependInfoTy, FlagsTy);
-  RecordDecl *KmpDependInfoRD =
-      cast<RecordDecl>(KmpDependInfoTy->getAsTagDecl());
+  auto *KmpDependInfoRD = KmpDependInfoTy->castAsRecordDecl();
   llvm::Type *LLVMFlagsTy = CGF.ConvertTypeForMem(FlagsTy);
 
   OMPIteratorGeneratorScope IteratorScope(
@@ -4333,8 +4324,7 @@ Address CGOpenMPRuntime::emitDepobjDependClause(
   unsigned NumDependencies = Dependencies.DepExprs.size();
   QualType FlagsTy;
   getDependTypes(C, KmpDependInfoTy, FlagsTy);
-  RecordDecl *KmpDependInfoRD =
-      cast<RecordDecl>(KmpDependInfoTy->getAsTagDecl());
+  auto *KmpDependInfoRD = KmpDependInfoTy->castAsRecordDecl();
 
   llvm::Value *Size;
   // Define type kmp_depend_info[<Dependencies.size()>];
@@ -4442,8 +4432,7 @@ void CGOpenMPRuntime::emitUpdateClause(CodeGenFunction &CGF, LValue DepobjLVal,
   ASTContext &C = CGM.getContext();
   QualType FlagsTy;
   getDependTypes(C, KmpDependInfoTy, FlagsTy);
-  RecordDecl *KmpDependInfoRD =
-      cast<RecordDecl>(KmpDependInfoTy->getAsTagDecl());
+  auto *KmpDependInfoRD = KmpDependInfoTy->castAsRecordDecl();
   llvm::Type *LLVMFlagsTy = CGF.ConvertTypeForMem(FlagsTy);
   llvm::Value *NumDeps;
   LValue Base;
@@ -11319,7 +11308,7 @@ void CGOpenMPRuntime::emitDoacrossInit(CodeGenFunction &CGF,
     RD->completeDefinition();
     KmpDimTy = C.getCanonicalTagType(RD);
   } else {
-    RD = cast<RecordDecl>(KmpDimTy->getAsTagDecl());
+    RD = KmpDimTy->castAsRecordDecl();
   }
   llvm::APInt Size(/*numBits=*/32, NumIterations.size());
   QualType ArrayTy = C.getConstantArrayType(KmpDimTy, Size, nullptr,

diff  --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp
index d3fbb4a2a9364..3ffe999d01178 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -700,8 +700,7 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
     break;
 
   case Type::Enum: {
-    const EnumDecl *ED =
-        cast<EnumType>(Ty)->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *ED = Ty->castAsEnumDecl();
     if (ED->isCompleteDefinition() || ED->isFixed())
       return ConvertType(ED->getIntegerType());
     // Return a placeholder 'i32' type.  This can be changed later when the
@@ -814,10 +813,7 @@ llvm::StructType *CodeGenTypes::ConvertRecordDeclType(const RecordDecl *RD) {
   if (const CXXRecordDecl *CRD = dyn_cast<CXXRecordDecl>(RD)) {
     for (const auto &I : CRD->bases()) {
       if (I.isVirtual()) continue;
-      ConvertRecordDeclType(I.getType()
-                                ->castAs<RecordType>()
-                                ->getOriginalDecl()
-                                ->getDefinitionOrSelf());
+      ConvertRecordDeclType(I.getType()->castAsRecordDecl());
     }
   }
 

diff  --git a/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp b/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp
index ac56dda74abb7..4e25e887c4133 100644
--- a/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp
+++ b/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp
@@ -85,24 +85,22 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
   Layout.push_back(0);
 
   // iterate over all fields of the record, including fields on base classes
-  llvm::SmallVector<const RecordType *> RecordTypes;
-  RecordTypes.push_back(RT);
-  while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
-    CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
+  llvm::SmallVector<CXXRecordDecl *> RecordDecls;
+  RecordDecls.push_back(RT->castAsCXXRecordDecl());
+  while (RecordDecls.back()->getNumBases()) {
+    CXXRecordDecl *D = RecordDecls.back();
     assert(D->getNumBases() == 1 &&
            "HLSL doesn't support multiple inheritance");
-    RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>());
+    RecordDecls.push_back(D->bases_begin()->getType()->castAsCXXRecordDecl());
   }
 
   unsigned FieldOffset;
   llvm::Type *FieldType;
 
-  while (!RecordTypes.empty()) {
-    const RecordType *RT = RecordTypes.back();
-    RecordTypes.pop_back();
+  while (!RecordDecls.empty()) {
+    const CXXRecordDecl *RD = RecordDecls.pop_back_val();
 
-    for (const auto *FD :
-         RT->getOriginalDecl()->getDefinitionOrSelf()->fields()) {
+    for (const auto *FD : RD->fields()) {
       assert((!PackOffsets || Index < PackOffsets->size()) &&
              "number of elements in layout struct does not match number of "
              "packoffset annotations");
@@ -148,7 +146,7 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
 
   // create the layout struct type; anonymous struct have empty name but
   // non-empty qualified name
-  const CXXRecordDecl *Decl = RT->getAsCXXRecordDecl();
+  const auto *Decl = RT->castAsCXXRecordDecl();
   std::string Name =
       Decl->getName().empty() ? "anon" : Decl->getQualifiedNameAsString();
   llvm::StructType *StructTy =

diff  --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp
index 569fbe9e3bd3a..885b700ffa193 100644
--- a/clang/lib/CodeGen/ItaniumCXXABI.cpp
+++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp
@@ -1397,9 +1397,7 @@ void ItaniumCXXABI::emitVirtualObjectDelete(CodeGenFunction &CGF,
     // to pass to the deallocation function.
 
     // Grab the vtable pointer as an intptr_t*.
-    auto *ClassDecl = cast<CXXRecordDecl>(
-                          ElementType->castAs<RecordType>()->getOriginalDecl())
-                          ->getDefinitionOrSelf();
+    auto *ClassDecl = ElementType->castAsCXXRecordDecl();
     llvm::Value *VTable = CGF.GetVTablePtr(Ptr, CGF.UnqualPtrTy, ClassDecl);
 
     // Track back to entry -2 and pull out the offset there.
@@ -1609,9 +1607,7 @@ llvm::Value *ItaniumCXXABI::EmitTypeid(CodeGenFunction &CGF,
                                        QualType SrcRecordTy,
                                        Address ThisPtr,
                                        llvm::Type *StdTypeInfoPtrTy) {
-  auto *ClassDecl =
-      cast<CXXRecordDecl>(SrcRecordTy->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
+  auto *ClassDecl = SrcRecordTy->castAsCXXRecordDecl();
   llvm::Value *Value = CGF.GetVTablePtr(ThisPtr, CGM.GlobalsInt8PtrTy,
                                         ClassDecl);
 
@@ -1783,9 +1779,7 @@ llvm::Value *ItaniumCXXABI::emitExactDynamicCast(
 llvm::Value *ItaniumCXXABI::emitDynamicCastToVoid(CodeGenFunction &CGF,
                                                   Address ThisAddr,
                                                   QualType SrcRecordTy) {
-  auto *ClassDecl =
-      cast<CXXRecordDecl>(SrcRecordTy->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
+  auto *ClassDecl = SrcRecordTy->castAsCXXRecordDecl();
   llvm::Value *OffsetToTop;
   if (CGM.getItaniumVTableContext().isRelativeLayout()) {
     // Get the vtable pointer.
@@ -3869,9 +3863,7 @@ static bool CanUseSingleInheritance(const CXXRecordDecl *RD) {
     return false;
 
   // Check that the class is dynamic iff the base is.
-  auto *BaseDecl = cast<CXXRecordDecl>(
-                       Base->getType()->castAs<RecordType>()->getOriginalDecl())
-                       ->getDefinitionOrSelf();
+  auto *BaseDecl = Base->getType()->castAsCXXRecordDecl();
   if (!BaseDecl->isEmpty() &&
       BaseDecl->isDynamicClass() != RD->isDynamicClass())
     return false;
@@ -4391,10 +4383,7 @@ static unsigned ComputeVMIClassTypeInfoFlags(const CXXBaseSpecifier *Base,
 
   unsigned Flags = 0;
 
-  auto *BaseDecl = cast<CXXRecordDecl>(
-                       Base->getType()->castAs<RecordType>()->getOriginalDecl())
-                       ->getDefinitionOrSelf();
-
+  auto *BaseDecl = Base->getType()->castAsCXXRecordDecl();
   if (Base->isVirtual()) {
     // Mark the virtual base as seen.
     if (!Bases.VirtualBases.insert(BaseDecl).second) {
@@ -4492,11 +4481,7 @@ void ItaniumRTTIBuilder::BuildVMIClassTypeInfo(const CXXRecordDecl *RD) {
     // The __base_type member points to the RTTI for the base type.
     Fields.push_back(ItaniumRTTIBuilder(CXXABI).BuildTypeInfo(Base.getType()));
 
-    auto *BaseDecl =
-        cast<CXXRecordDecl>(
-            Base.getType()->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
-
+    auto *BaseDecl = Base.getType()->castAsCXXRecordDecl();
     int64_t OffsetFlags = 0;
 
     // All but the lower 8 bits of __offset_flags are a signed offset.

diff  --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 1a878222cde18..6bdfdbcf0e89a 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -374,8 +374,8 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
 
   if (!passAsAggregateType(Ty)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     if (const auto *EIT = Ty->getAs<BitIntType>())
       if (EIT->getNumBits() > 128)
@@ -546,9 +546,8 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
 
   if (!passAsAggregateType(RetTy)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-      RetTy =
-          EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = RetTy->getAsEnumDecl())
+      RetTy = ED->getIntegerType();
 
     if (const auto *EIT = RetTy->getAs<BitIntType>())
       if (EIT->getNumBits() > 128)

diff  --git a/clang/lib/CodeGen/Targets/ARC.cpp b/clang/lib/CodeGen/Targets/ARC.cpp
index ace524e1976d9..56bbae0eace94 100644
--- a/clang/lib/CodeGen/Targets/ARC.cpp
+++ b/clang/lib/CodeGen/Targets/ARC.cpp
@@ -105,8 +105,8 @@ ABIArgInfo ARCABIInfo::classifyArgumentType(QualType Ty,
   }
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   auto SizeInRegs = llvm::alignTo(getContext().getTypeSize(Ty), 32) / 32;
 

diff  --git a/clang/lib/CodeGen/Targets/ARM.cpp b/clang/lib/CodeGen/Targets/ARM.cpp
index c87a0ab52557d..d5c86e19fd358 100644
--- a/clang/lib/CodeGen/Targets/ARM.cpp
+++ b/clang/lib/CodeGen/Targets/ARM.cpp
@@ -382,8 +382,8 @@ ABIArgInfo ARMABIInfo::classifyArgumentType(QualType Ty, bool isVariadic,
 
   if (!isAggregateTypeForABI(Ty)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>()) {
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl()) {
+      Ty = ED->getIntegerType();
     }
 
     if (const auto *EIT = Ty->getAs<BitIntType>())
@@ -592,9 +592,8 @@ ABIArgInfo ARMABIInfo::classifyReturnType(QualType RetTy, bool isVariadic,
 
   if (!isAggregateTypeForABI(RetTy)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-      RetTy =
-          EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = RetTy->getAsEnumDecl())
+      RetTy = ED->getIntegerType();
 
     if (const auto *EIT = RetTy->getAs<BitIntType>())
       if (EIT->getNumBits() > 64)

diff  --git a/clang/lib/CodeGen/Targets/BPF.cpp b/clang/lib/CodeGen/Targets/BPF.cpp
index 87d50e671d251..3a7af346f1132 100644
--- a/clang/lib/CodeGen/Targets/BPF.cpp
+++ b/clang/lib/CodeGen/Targets/BPF.cpp
@@ -47,8 +47,8 @@ class BPFABIInfo : public DefaultABIInfo {
       }
     }
 
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     ASTContext &Context = getContext();
     if (const auto *EIT = Ty->getAs<BitIntType>())
@@ -69,9 +69,8 @@ class BPFABIInfo : public DefaultABIInfo {
                                      getDataLayout().getAllocaAddrSpace());
 
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-      RetTy =
-          EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = RetTy->getAsEnumDecl())
+      RetTy = ED->getIntegerType();
 
     ASTContext &Context = getContext();
     if (const auto *EIT = RetTy->getAs<BitIntType>())

diff  --git a/clang/lib/CodeGen/Targets/CSKY.cpp b/clang/lib/CodeGen/Targets/CSKY.cpp
index 14deaf106e01e..b9254208f912a 100644
--- a/clang/lib/CodeGen/Targets/CSKY.cpp
+++ b/clang/lib/CodeGen/Targets/CSKY.cpp
@@ -115,8 +115,8 @@ ABIArgInfo CSKYABIInfo::classifyArgumentType(QualType Ty, int &ArgGPRsLeft,
 
   if (!isAggregateTypeForABI(Ty)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     // All integral types are promoted to XLen width, unless passed on the
     // stack.

diff  --git a/clang/lib/CodeGen/Targets/Hexagon.cpp b/clang/lib/CodeGen/Targets/Hexagon.cpp
index 0c423429eda4f..97a9300a1b8cd 100644
--- a/clang/lib/CodeGen/Targets/Hexagon.cpp
+++ b/clang/lib/CodeGen/Targets/Hexagon.cpp
@@ -97,8 +97,8 @@ ABIArgInfo HexagonABIInfo::classifyArgumentType(QualType Ty,
                                                 unsigned *RegsLeft) const {
   if (!isAggregateTypeForABI(Ty)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     uint64_t Size = getContext().getTypeSize(Ty);
     if (Size <= 64)
@@ -160,9 +160,8 @@ ABIArgInfo HexagonABIInfo::classifyReturnType(QualType RetTy) const {
 
   if (!isAggregateTypeForABI(RetTy)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-      RetTy =
-          EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = RetTy->getAsEnumDecl())
+      RetTy = ED->getIntegerType();
 
     if (Size > 64 && RetTy->isBitIntType())
       return getNaturalAlignIndirect(

diff  --git a/clang/lib/CodeGen/Targets/Lanai.cpp b/clang/lib/CodeGen/Targets/Lanai.cpp
index 08cb36034f6fd..675009f0e5748 100644
--- a/clang/lib/CodeGen/Targets/Lanai.cpp
+++ b/clang/lib/CodeGen/Targets/Lanai.cpp
@@ -125,8 +125,8 @@ ABIArgInfo LanaiABIInfo::classifyArgumentType(QualType Ty,
   }
 
   // Treat an enum type as its underlying type.
-  if (const auto *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   bool InReg = shouldUseInReg(Ty, State);
 

diff  --git a/clang/lib/CodeGen/Targets/LoongArch.cpp b/clang/lib/CodeGen/Targets/LoongArch.cpp
index af863e6101e2c..0ea08ff533916 100644
--- a/clang/lib/CodeGen/Targets/LoongArch.cpp
+++ b/clang/lib/CodeGen/Targets/LoongArch.cpp
@@ -180,10 +180,7 @@ bool LoongArchABIInfo::detectFARsEligibleStructHelper(
     // If this is a C++ record, check the bases first.
     if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
       for (const CXXBaseSpecifier &B : CXXRD->bases()) {
-        const auto *BDecl =
-            cast<CXXRecordDecl>(
-                B.getType()->castAs<RecordType>()->getOriginalDecl())
-                ->getDefinitionOrSelf();
+        const auto *BDecl = B.getType()->castAsCXXRecordDecl();
         if (!detectFARsEligibleStructHelper(
                 B.getType(), CurOff + Layout.getBaseClassOffset(BDecl),
                 Field1Ty, Field1Off, Field2Ty, Field2Off))
@@ -370,8 +367,8 @@ ABIArgInfo LoongArchABIInfo::classifyArgumentType(QualType Ty, bool IsFixed,
 
   if (!isAggregateTypeForABI(Ty) && !Ty->isVectorType()) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     // All integral types are promoted to GRLen width.
     if (Size < GRLen && Ty->isIntegralOrEnumerationType())

diff  --git a/clang/lib/CodeGen/Targets/Mips.cpp b/clang/lib/CodeGen/Targets/Mips.cpp
index e12a34ce07bbe..9c3016fa1f5a5 100644
--- a/clang/lib/CodeGen/Targets/Mips.cpp
+++ b/clang/lib/CodeGen/Targets/Mips.cpp
@@ -241,8 +241,8 @@ MipsABIInfo::classifyArgumentType(QualType Ty, uint64_t &Offset) const {
   }
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   // Make sure we pass indirectly things that are too large.
   if (const auto *EIT = Ty->getAs<BitIntType>())
@@ -332,8 +332,8 @@ ABIArgInfo MipsABIInfo::classifyReturnType(QualType RetTy) const {
   }
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-    RetTy = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = RetTy->getAsEnumDecl())
+    RetTy = ED->getIntegerType();
 
   // Make sure we pass indirectly things that are too large.
   if (const auto *EIT = RetTy->getAs<BitIntType>())

diff  --git a/clang/lib/CodeGen/Targets/NVPTX.cpp b/clang/lib/CodeGen/Targets/NVPTX.cpp
index 106251fb27024..400fa200bf7bc 100644
--- a/clang/lib/CodeGen/Targets/NVPTX.cpp
+++ b/clang/lib/CodeGen/Targets/NVPTX.cpp
@@ -172,8 +172,8 @@ ABIArgInfo NVPTXABIInfo::classifyReturnType(QualType RetTy) const {
     return ABIArgInfo::getDirect();
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-    RetTy = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = RetTy->getAsEnumDecl())
+    RetTy = ED->getIntegerType();
 
   return (isPromotableIntegerTypeForABI(RetTy) ? ABIArgInfo::getExtend(RetTy)
                                                : ABIArgInfo::getDirect());
@@ -181,8 +181,8 @@ ABIArgInfo NVPTXABIInfo::classifyReturnType(QualType RetTy) const {
 
 ABIArgInfo NVPTXABIInfo::classifyArgumentType(QualType Ty) const {
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   // Return aggregates type as indirect by value
   if (isAggregateTypeForABI(Ty)) {

diff  --git a/clang/lib/CodeGen/Targets/PPC.cpp b/clang/lib/CodeGen/Targets/PPC.cpp
index a297833a57ef4..380e8c06c46f0 100644
--- a/clang/lib/CodeGen/Targets/PPC.cpp
+++ b/clang/lib/CodeGen/Targets/PPC.cpp
@@ -153,8 +153,8 @@ class AIXTargetCodeGenInfo : public TargetCodeGenInfo {
 // extended to 32/64 bits.
 bool AIXABIInfo::isPromotableTypeForABI(QualType Ty) const {
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   // Promotable integer types are required to be promoted by the ABI.
   if (getContext().isPromotableIntegerType(Ty))
@@ -705,8 +705,8 @@ class PPC64TargetCodeGenInfo : public TargetCodeGenInfo {
 bool
 PPC64_SVR4_ABIInfo::isPromotableTypeForABI(QualType Ty) const {
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   // Promotable integer types are required to be promoted by the ABI.
   if (isPromotableIntegerTypeForABI(Ty))

diff  --git a/clang/lib/CodeGen/Targets/RISCV.cpp b/clang/lib/CodeGen/Targets/RISCV.cpp
index 264595a6e4b76..049f61bb86b1e 100644
--- a/clang/lib/CodeGen/Targets/RISCV.cpp
+++ b/clang/lib/CodeGen/Targets/RISCV.cpp
@@ -264,10 +264,7 @@ bool RISCVABIInfo::detectFPCCEligibleStructHelper(QualType Ty, CharUnits CurOff,
     // If this is a C++ record, check the bases first.
     if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
       for (const CXXBaseSpecifier &B : CXXRD->bases()) {
-        const auto *BDecl =
-            cast<CXXRecordDecl>(
-                B.getType()->castAs<RecordType>()->getOriginalDecl())
-                ->getDefinitionOrSelf();
+        const auto *BDecl = B.getType()->castAsCXXRecordDecl();
         CharUnits BaseOff = Layout.getBaseClassOffset(BDecl);
         bool Ret = detectFPCCEligibleStructHelper(B.getType(), CurOff + BaseOff,
                                                   Field1Ty, Field1Off, Field2Ty,
@@ -680,8 +677,8 @@ ABIArgInfo RISCVABIInfo::classifyArgumentType(QualType Ty, bool IsFixed,
 
   if (!isAggregateTypeForABI(Ty) && !Ty->isVectorType()) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     // All integral types are promoted to XLen width
     if (Size < XLen && Ty->isIntegralOrEnumerationType()) {

diff  --git a/clang/lib/CodeGen/Targets/Sparc.cpp b/clang/lib/CodeGen/Targets/Sparc.cpp
index 13a42f66f5843..5f3c15d106eb6 100644
--- a/clang/lib/CodeGen/Targets/Sparc.cpp
+++ b/clang/lib/CodeGen/Targets/Sparc.cpp
@@ -237,8 +237,8 @@ SparcV9ABIInfo::classifyType(QualType Ty, unsigned SizeLimit) const {
         /*ByVal=*/false);
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   // Integer types smaller than a register are extended.
   if (Size < 64 && Ty->isIntegerType())

diff  --git a/clang/lib/CodeGen/Targets/SystemZ.cpp b/clang/lib/CodeGen/Targets/SystemZ.cpp
index c3f0cba4cbef6..9b6b72b1dcc05 100644
--- a/clang/lib/CodeGen/Targets/SystemZ.cpp
+++ b/clang/lib/CodeGen/Targets/SystemZ.cpp
@@ -145,8 +145,8 @@ class SystemZTargetCodeGenInfo : public TargetCodeGenInfo {
 
 bool SystemZABIInfo::isPromotableIntegerTypeForABI(QualType Ty) const {
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   // Promotable integer types are required to be promoted by the ABI.
   if (ABIInfo::isPromotableIntegerTypeForABI(Ty))
@@ -208,10 +208,8 @@ llvm::Type *SystemZABIInfo::getFPArgumentType(QualType Ty,
 }
 
 QualType SystemZABIInfo::GetSingleElementType(QualType Ty) const {
-  const RecordType *RT = Ty->getAs<RecordType>();
-
-  if (RT && RT->isStructureOrClassType()) {
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *RD = Ty->getAsRecordDecl();
+  if (RD && RD->isStructureOrClass()) {
     QualType Found;
 
     // If this is a C++ record, check the bases first.

diff  --git a/clang/lib/CodeGen/Targets/WebAssembly.cpp b/clang/lib/CodeGen/Targets/WebAssembly.cpp
index ac8dcd2a0540a..ebe996a4edd8d 100644
--- a/clang/lib/CodeGen/Targets/WebAssembly.cpp
+++ b/clang/lib/CodeGen/Targets/WebAssembly.cpp
@@ -115,11 +115,9 @@ ABIArgInfo WebAssemblyABIInfo::classifyArgumentType(QualType Ty) const {
       return ABIArgInfo::getDirect(CGT.ConvertType(QualType(SeltTy, 0)));
     // For the experimental multivalue ABI, fully expand all other aggregates
     if (Kind == WebAssemblyABIKind::ExperimentalMV) {
-      const RecordType *RT = Ty->getAs<RecordType>();
-      assert(RT);
+      const auto *RD = Ty->castAsRecordDecl();
       bool HasBitField = false;
-      for (auto *Field :
-           RT->getOriginalDecl()->getDefinitionOrSelf()->fields()) {
+      for (auto *Field : RD->fields()) {
         if (Field->isBitField()) {
           HasBitField = true;
           break;

diff  --git a/clang/lib/CodeGen/Targets/X86.cpp b/clang/lib/CodeGen/Targets/X86.cpp
index f04d24db26a25..386772a96c7ff 100644
--- a/clang/lib/CodeGen/Targets/X86.cpp
+++ b/clang/lib/CodeGen/Targets/X86.cpp
@@ -552,8 +552,8 @@ ABIArgInfo X86_32ABIInfo::classifyReturnType(QualType RetTy,
   }
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-    RetTy = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = RetTy->getAsEnumDecl())
+    RetTy = ED->getIntegerType();
 
   if (const auto *EIT = RetTy->getAs<BitIntType>())
     if (EIT->getNumBits() > 64)
@@ -881,9 +881,8 @@ ABIArgInfo X86_32ABIInfo::classifyArgumentType(QualType Ty, CCState &State,
     return ABIArgInfo::getDirect();
   }
 
-
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   bool InReg = shouldPrimitiveUseInReg(Ty, State);
 
@@ -1846,10 +1845,9 @@ void X86_64ABIInfo::classify(QualType Ty, uint64_t OffsetBase, Class &Lo,
     return;
   }
 
-  if (const EnumType *ET = Ty->getAs<EnumType>()) {
+  if (const auto *ED = Ty->getAsEnumDecl()) {
     // Classify the underlying integer type.
-    classify(ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType(),
-             OffsetBase, Lo, Hi, isNamedArg);
+    classify(ED->getIntegerType(), OffsetBase, Lo, Hi, isNamedArg);
     return;
   }
 
@@ -2071,11 +2069,7 @@ void X86_64ABIInfo::classify(QualType Ty, uint64_t OffsetBase, Class &Lo,
       for (const auto &I : CXXRD->bases()) {
         assert(!I.isVirtual() && !I.getType()->isDependentType() &&
                "Unexpected base class!");
-        const auto *Base =
-            cast<CXXRecordDecl>(
-                I.getType()->castAs<RecordType>()->getOriginalDecl())
-                ->getDefinitionOrSelf();
-
+        const auto *Base = I.getType()->castAsCXXRecordDecl();
         // Classify this field.
         //
         // AMD64-ABI 3.2.3p2: Rule 3. If the size of the aggregate exceeds a
@@ -2187,8 +2181,8 @@ ABIArgInfo X86_64ABIInfo::getIndirectReturnResult(QualType Ty) const {
   // place naturally.
   if (!isAggregateTypeForABI(Ty)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     if (Ty->isBitIntType())
       return getNaturalAlignIndirect(Ty, getDataLayout().getAllocaAddrSpace());
@@ -2229,8 +2223,8 @@ ABIArgInfo X86_64ABIInfo::getIndirectResult(QualType Ty,
   if (!isAggregateTypeForABI(Ty) && !IsIllegalVectorType(Ty) &&
       !Ty->isBitIntType()) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     return (isPromotableIntegerTypeForABI(Ty) ? ABIArgInfo::getExtend(Ty)
                                               : ABIArgInfo::getDirect());
@@ -2358,10 +2352,7 @@ static bool BitsContainNoUserData(QualType Ty, unsigned StartBit,
       for (const auto &I : CXXRD->bases()) {
         assert(!I.isVirtual() && !I.getType()->isDependentType() &&
                "Unexpected base class!");
-        const auto *Base =
-            cast<CXXRecordDecl>(
-                I.getType()->castAs<RecordType>()->getOriginalDecl())
-                ->getDefinitionOrSelf();
+        const auto *Base = I.getType()->castAsCXXRecordDecl();
 
         // If the base is after the span we care about, ignore it.
         unsigned BaseOffset = Context.toBits(Layout.getBaseClassOffset(Base));
@@ -2641,9 +2632,8 @@ ABIArgInfo X86_64ABIInfo::classifyReturnType(QualType RetTy) const {
     // so that the parameter gets the right LLVM IR attributes.
     if (Hi == NoClass && isa<llvm::IntegerType>(ResType)) {
       // Treat an enum type as its underlying type.
-      if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-        RetTy =
-            EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+      if (const auto *ED = RetTy->getAsEnumDecl())
+        RetTy = ED->getIntegerType();
 
       if (RetTy->isIntegralOrEnumerationType() &&
           isPromotableIntegerTypeForABI(RetTy))
@@ -2792,8 +2782,8 @@ X86_64ABIInfo::classifyArgumentType(QualType Ty, unsigned freeIntRegs,
     // so that the parameter gets the right LLVM IR attributes.
     if (Hi == NoClass && isa<llvm::IntegerType>(ResType)) {
       // Treat an enum type as its underlying type.
-      if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-        Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+      if (const auto *ED = Ty->getAsEnumDecl())
+        Ty = ED->getIntegerType();
 
       if (Ty->isIntegralOrEnumerationType() &&
           isPromotableIntegerTypeForABI(Ty))

diff  --git a/clang/lib/Frontend/Rewrite/RewriteModernObjC.cpp b/clang/lib/Frontend/Rewrite/RewriteModernObjC.cpp
index f57ccd30c59b9..ad8916f817486 100644
--- a/clang/lib/Frontend/Rewrite/RewriteModernObjC.cpp
+++ b/clang/lib/Frontend/Rewrite/RewriteModernObjC.cpp
@@ -3638,8 +3638,7 @@ bool RewriteModernObjC::RewriteObjCFieldDeclType(QualType &Type,
     return RewriteObjCFieldDeclType(ElemTy, Result);
   }
   else if (Type->isRecordType()) {
-    RecordDecl *RD =
-        Type->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    auto *RD = Type->castAsRecordDecl();
     if (RD->isCompleteDefinition()) {
       if (RD->isStruct())
         Result += "\n\tstruct ";
@@ -3660,28 +3659,26 @@ bool RewriteModernObjC::RewriteObjCFieldDeclType(QualType &Type,
       Result += "\t} ";
       return true;
     }
-  }
-  else if (Type->isEnumeralType()) {
-    EnumDecl *ED =
-        Type->castAs<EnumType>()->getOriginalDecl()->getDefinitionOrSelf();
-    if (ED->isCompleteDefinition()) {
-      Result += "\n\tenum ";
-      Result += ED->getName();
-      if (GlobalDefinedTags.count(ED)) {
-        // Enum is globall defined, use it.
-        Result += " ";
-        return true;
-      }
-
-      Result += " {\n";
-      for (const auto *EC : ED->enumerators()) {
-        Result += "\t"; Result += EC->getName(); Result += " = ";
-        Result += toString(EC->getInitVal(), 10);
-        Result += ",\n";
-      }
-      Result += "\t} ";
+  } else if (auto *ED = Type->getAsEnumDecl();
+             ED && ED->isCompleteDefinition()) {
+    Result += "\n\tenum ";
+    Result += ED->getName();
+    if (GlobalDefinedTags.count(ED)) {
+      // Enum is globall defined, use it.
+      Result += " ";
       return true;
     }
+
+    Result += " {\n";
+    for (const auto *EC : ED->enumerators()) {
+      Result += "\t";
+      Result += EC->getName();
+      Result += " = ";
+      Result += toString(EC->getInitVal(), 10);
+      Result += ",\n";
+    }
+    Result += "\t} ";
+    return true;
   }
 
   Result += "\t";
@@ -5715,10 +5712,7 @@ void RewriteModernObjC::HandleDeclInMainFile(Decl *D) {
           }
         }
       } else if (VD->getType()->isRecordType()) {
-        RecordDecl *RD = VD->getType()
-                             ->castAs<RecordType>()
-                             ->getOriginalDecl()
-                             ->getDefinitionOrSelf();
+        auto *RD = VD->getType()->castAsRecordDecl();
         if (RD->isCompleteDefinition())
           RewriteRecordBody(RD);
       }

diff  --git a/clang/lib/Frontend/Rewrite/RewriteObjC.cpp b/clang/lib/Frontend/Rewrite/RewriteObjC.cpp
index 0242fc1f6e827..b9c025de87739 100644
--- a/clang/lib/Frontend/Rewrite/RewriteObjC.cpp
+++ b/clang/lib/Frontend/Rewrite/RewriteObjC.cpp
@@ -4835,10 +4835,7 @@ void RewriteObjC::HandleDeclInMainFile(Decl *D) {
           }
         }
       } else if (VD->getType()->isRecordType()) {
-        RecordDecl *RD = VD->getType()
-                             ->castAs<RecordType>()
-                             ->getOriginalDecl()
-                             ->getDefinitionOrSelf();
+        auto *RD = VD->getType()->castAsRecordDecl();
         if (RD->isCompleteDefinition())
           RewriteRecordBody(RD);
       }

diff  --git a/clang/lib/Index/IndexTypeSourceInfo.cpp b/clang/lib/Index/IndexTypeSourceInfo.cpp
index 9a699c3c896de..74c6c116b274e 100644
--- a/clang/lib/Index/IndexTypeSourceInfo.cpp
+++ b/clang/lib/Index/IndexTypeSourceInfo.cpp
@@ -61,7 +61,7 @@ class TypeIndexer : public RecursiveASTVisitor<TypeIndexer> {
     SourceLocation Loc = TL.getNameLoc();
     TypedefNameDecl *ND = TL.getDecl();
     if (ND->isTransparentTag()) {
-      TagDecl *Underlying = ND->getUnderlyingType()->getAsTagDecl();
+      auto *Underlying = ND->getUnderlyingType()->castAsTagDecl();
       return IndexCtx.handleReference(Underlying, Loc, Parent,
                                       ParentDC, SymbolRoleSet(), Relations);
     }

diff  --git a/clang/lib/Interpreter/InterpreterValuePrinter.cpp b/clang/lib/Interpreter/InterpreterValuePrinter.cpp
index a329368fba144..54abfa6dbb9d8 100644
--- a/clang/lib/Interpreter/InterpreterValuePrinter.cpp
+++ b/clang/lib/Interpreter/InterpreterValuePrinter.cpp
@@ -101,15 +101,11 @@ static std::string EnumToString(const Value &V) {
   llvm::raw_string_ostream SS(Str);
   ASTContext &Ctx = const_cast<ASTContext &>(V.getASTContext());
 
-  QualType DesugaredTy = V.getType().getDesugaredType(Ctx);
-  const EnumType *EnumTy = DesugaredTy.getNonReferenceType()->getAs<EnumType>();
-  assert(EnumTy && "Fail to cast to enum type");
-
-  EnumDecl *ED = EnumTy->getOriginalDecl()->getDefinitionOrSelf();
   uint64_t Data = V.convertTo<uint64_t>();
   bool IsFirst = true;
-  llvm::APSInt AP = Ctx.MakeIntValue(Data, DesugaredTy);
+  llvm::APSInt AP = Ctx.MakeIntValue(Data, V.getType());
 
+  auto *ED = V.getType()->castAsEnumDecl();
   for (auto I = ED->enumerator_begin(), E = ED->enumerator_end(); I != E; ++I) {
     if (I->getInitVal() == AP) {
       if (!IsFirst)
@@ -665,8 +661,8 @@ __clang_Interpreter_SetValueNoAlloc(void *This, void *OutVal, void *OpaqueType,
   if (VRef.getKind() == Value::K_PtrOrObj) {
     VRef.setPtr(va_arg(args, void *));
   } else {
-    if (const auto *ET = QT->getAs<EnumType>())
-      QT = ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = QT->getAsEnumDecl())
+      QT = ED->getIntegerType();
     switch (QT->castAs<BuiltinType>()->getKind()) {
     default:
       llvm_unreachable("unknown type kind!");

diff  --git a/clang/lib/Interpreter/Value.cpp b/clang/lib/Interpreter/Value.cpp
index 2ef057bc5d0a5..d4c9d51ffcb61 100644
--- a/clang/lib/Interpreter/Value.cpp
+++ b/clang/lib/Interpreter/Value.cpp
@@ -101,8 +101,8 @@ static Value::Kind ConvertQualTypeToKind(const ASTContext &Ctx, QualType QT) {
   if (Ctx.hasSameType(QT, Ctx.VoidTy))
     return Value::K_Void;
 
-  if (const auto *ET = QT->getAs<EnumType>())
-    QT = ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = QT->getAsEnumDecl())
+    QT = ED->getIntegerType();
 
   const auto *BT = QT->getAs<BuiltinType>();
   if (!BT || BT->isNullPtrType())

diff  --git a/clang/lib/Sema/SemaAccess.cpp b/clang/lib/Sema/SemaAccess.cpp
index 7ceea76012e0c..17415b4185eff 100644
--- a/clang/lib/Sema/SemaAccess.cpp
+++ b/clang/lib/Sema/SemaAccess.cpp
@@ -1784,10 +1784,7 @@ Sema::AccessResult Sema::CheckMemberOperatorAccess(SourceLocation OpLoc,
   if (!getLangOpts().AccessControl || Found.getAccess() == AS_public)
     return AR_accessible;
 
-  const RecordType *RT = ObjectExpr->getType()->castAs<RecordType>();
-  CXXRecordDecl *NamingClass =
-      cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
-
+  auto *NamingClass = ObjectExpr->getType()->castAsCXXRecordDecl();
   AccessTarget Entity(Context, AccessTarget::Member, NamingClass, Found,
                       ObjectExpr->getType());
   Entity.setDiag(diag::err_access) << ObjectExpr->getSourceRange() << Range;

diff  --git a/clang/lib/Sema/SemaBPF.cpp b/clang/lib/Sema/SemaBPF.cpp
index 6428435ed9d2a..c5e97484dab78 100644
--- a/clang/lib/Sema/SemaBPF.cpp
+++ b/clang/lib/Sema/SemaBPF.cpp
@@ -99,13 +99,12 @@ static bool isValidPreserveEnumValueArg(Expr *Arg) {
     return false;
 
   // The type must be EnumType.
-  const Type *Ty = ArgType->getUnqualifiedDesugaredType();
-  const auto *ET = Ty->getAs<EnumType>();
-  if (!ET)
+  const auto *ED = ArgType->getAsEnumDecl();
+  if (!ED)
     return false;
 
   // The enum value must be supported.
-  return llvm::is_contained(ET->getOriginalDecl()->enumerators(), Enumerator);
+  return llvm::is_contained(ED->enumerators(), Enumerator);
 }
 
 bool SemaBPF::CheckBPFBuiltinFunctionCall(unsigned BuiltinID,

diff  --git a/clang/lib/Sema/SemaCXXScopeSpec.cpp b/clang/lib/Sema/SemaCXXScopeSpec.cpp
index 45de8ff3ba264..ef14dde5203a6 100644
--- a/clang/lib/Sema/SemaCXXScopeSpec.cpp
+++ b/clang/lib/Sema/SemaCXXScopeSpec.cpp
@@ -133,11 +133,8 @@ DeclContext *Sema::computeDeclContext(const CXXScopeSpec &SS,
     return const_cast<NamespaceDecl *>(
         NNS.getAsNamespaceAndPrefix().Namespace->getNamespace());
 
-  case NestedNameSpecifier::Kind::Type: {
-    auto *TD = NNS.getAsType()->getAsTagDecl();
-    assert(TD && "Non-tag type in nested-name-specifier");
-    return TD;
-  }
+  case NestedNameSpecifier::Kind::Type:
+    return NNS.getAsType()->castAsTagDecl();
 
   case NestedNameSpecifier::Kind::Global:
     return Context.getTranslationUnitDecl();

diff  --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index 933a6c558e7ac..de22419ee35de 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -1484,8 +1484,7 @@ static TryCastResult TryStaticCast(Sema &Self, ExprResult &SrcExpr,
     if (SrcType->isIntegralOrEnumerationType()) {
       // [expr.static.cast]p10 If the enumeration type has a fixed underlying
       // type, the value is first converted to that type by integral conversion
-      const EnumType *Enum = DestType->castAs<EnumType>();
-      const EnumDecl *ED = Enum->getOriginalDecl()->getDefinitionOrSelf();
+      const auto *ED = DestType->castAsEnumDecl();
       Kind = ED->isFixed() && ED->getIntegerType()->isBooleanType()
                  ? CK_IntegralToBoolean
                  : CK_IntegralCast;

diff  --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 4cba4108c3500..1a30793f8ca44 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5278,13 +5278,10 @@ bool Sema::BuiltinVAStart(unsigned BuiltinID, CallExpr *TheCall) {
              // extra checking to see what their promotable type actually is.
              if (!Context.isPromotableIntegerType(Type))
                return false;
-             if (!Type->isEnumeralType())
+             const auto *ED = Type->getAsEnumDecl();
+             if (!ED)
                return true;
-             const EnumDecl *ED = Type->castAs<EnumType>()
-                                      ->getOriginalDecl()
-                                      ->getDefinitionOrSelf();
-             return !(ED &&
-                      Context.typesAreCompatible(ED->getPromotionType(), Type));
+             return !Context.typesAreCompatible(ED->getPromotionType(), Type);
            }()) {
     unsigned Reason = 0;
     if (Type->isReferenceType())  Reason = 1;
@@ -8432,10 +8429,9 @@ CheckPrintfHandler::checkFormatExpr(const analyze_printf::PrintfSpecifier &FS,
   bool IsEnum = false;
   bool IsScopedEnum = false;
   QualType IntendedTy = ExprTy;
-  if (auto EnumTy = ExprTy->getAs<EnumType>()) {
-    IntendedTy =
-        EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
-    if (EnumTy->isUnscopedEnumerationType()) {
+  if (const auto *ED = ExprTy->getAsEnumDecl()) {
+    IntendedTy = ED->getIntegerType();
+    if (!ED->isScoped()) {
       ExprTy = IntendedTy;
       // This controls whether we're talking about the underlying type or not,
       // which we only want to do when it's an unscoped enum.
@@ -9735,10 +9731,7 @@ struct SearchNonTrivialToInitializeField
     S.DiagRuntimeBehavior(SL, E, S.PDiag(diag::note_nontrivial_field) << 1);
   }
   void visitStruct(QualType FT, SourceLocation SL) {
-    for (const FieldDecl *FD : FT->castAs<RecordType>()
-                                   ->getOriginalDecl()
-                                   ->getDefinitionOrSelf()
-                                   ->fields())
+    for (const FieldDecl *FD : FT->castAsRecordDecl()->fields())
       visit(FD->getType(), FD->getLocation());
   }
   void visitArray(QualType::PrimitiveDefaultInitializeKind PDIK,
@@ -9783,10 +9776,7 @@ struct SearchNonTrivialToCopyField
     S.DiagRuntimeBehavior(SL, E, S.PDiag(diag::note_nontrivial_field) << 0);
   }
   void visitStruct(QualType FT, SourceLocation SL) {
-    for (const FieldDecl *FD : FT->castAs<RecordType>()
-                                   ->getOriginalDecl()
-                                   ->getDefinitionOrSelf()
-                                   ->fields())
+    for (const FieldDecl *FD : FT->castAsRecordDecl()->fields())
       visit(FD->getType(), FD->getLocation());
   }
   void visitArray(QualType::PrimitiveCopyKind PCK, const ArrayType *AT,
@@ -10586,20 +10576,15 @@ struct IntRange {
 
     if (!C.getLangOpts().CPlusPlus) {
       // For enum types in C code, use the underlying datatype.
-      if (const auto *ET = dyn_cast<EnumType>(T))
-        T = ET->getOriginalDecl()
-                ->getDefinitionOrSelf()
-                ->getIntegerType()
-                .getDesugaredType(C)
-                .getTypePtr();
-    } else if (const auto *ET = dyn_cast<EnumType>(T)) {
+      if (const auto *ED = T->getAsEnumDecl())
+        T = ED->getIntegerType().getDesugaredType(C).getTypePtr();
+    } else if (auto *Enum = T->getAsEnumDecl()) {
       // For enum types in C++, use the known bit width of the enumerators.
-      EnumDecl *Enum = ET->getOriginalDecl()->getDefinitionOrSelf();
       // In C++11, enums can have a fixed underlying type. Use this type to
       // compute the range.
       if (Enum->isFixed()) {
         return IntRange(C.getIntWidth(QualType(T, 0)),
-                        !ET->isSignedIntegerOrEnumerationType());
+                        !Enum->getIntegerType()->isSignedIntegerType());
       }
 
       unsigned NumPositive = Enum->getNumPositiveBits();
@@ -10635,10 +10620,8 @@ struct IntRange {
       T = CT->getElementType().getTypePtr();
     if (const AtomicType *AT = dyn_cast<AtomicType>(T))
       T = AT->getValueType().getTypePtr();
-    if (const EnumType *ET = dyn_cast<EnumType>(T))
-      T = C.getCanonicalType(
-               ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType())
-              .getTypePtr();
+    if (const auto *ED = T->getAsEnumDecl())
+      T = C.getCanonicalType(ED->getIntegerType()).getTypePtr();
 
     if (const auto *EIT = dyn_cast<BitIntType>(T))
       return IntRange(EIT->getNumBits(), EIT->isUnsigned());
@@ -11609,10 +11592,7 @@ static bool AnalyzeBitFieldAssignment(Sema &S, FieldDecl *Bitfield, Expr *Init,
   if (BitfieldType->isBooleanType())
      return false;
 
-  if (BitfieldType->isEnumeralType()) {
-    EnumDecl *BitfieldEnumDecl = BitfieldType->castAs<EnumType>()
-                                     ->getOriginalDecl()
-                                     ->getDefinitionOrSelf();
+  if (auto *BitfieldEnumDecl = BitfieldType->getAsEnumDecl()) {
     // If the underlying enum type was not explicitly specified as an unsigned
     // type and the enum contain only positive values, MSVC++ will cause an
     // inconsistency by storing this as a signed type.
@@ -11641,15 +11621,14 @@ static bool AnalyzeBitFieldAssignment(Sema &S, FieldDecl *Bitfield, Expr *Init,
     // The RHS is not constant.  If the RHS has an enum type, make sure the
     // bitfield is wide enough to hold all the values of the enum without
     // truncation.
-    const auto *EnumTy = OriginalInit->getType()->getAs<EnumType>();
+    const auto *ED = OriginalInit->getType()->getAsEnumDecl();
     const PreferredTypeAttr *PTAttr = nullptr;
-    if (!EnumTy) {
+    if (!ED) {
       PTAttr = Bitfield->getAttr<PreferredTypeAttr>();
       if (PTAttr)
-        EnumTy = PTAttr->getType()->getAs<EnumType>();
+        ED = PTAttr->getType()->getAsEnumDecl();
     }
-    if (EnumTy) {
-      EnumDecl *ED = EnumTy->getOriginalDecl()->getDefinitionOrSelf();
+    if (ED) {
       bool SignedBitfield = BitfieldType->isSignedIntegerOrEnumerationType();
 
       // Enum types are implicitly signed on Windows, so check if there are any
@@ -15354,17 +15333,14 @@ static bool isLayoutCompatible(const ASTContext &C, QualType T1, QualType T2) {
   if (TC1 != TC2)
     return false;
 
-  if (TC1 == Type::Enum) {
-    return isLayoutCompatible(
-        C, cast<EnumType>(T1)->getOriginalDecl()->getDefinitionOrSelf(),
-        cast<EnumType>(T2)->getOriginalDecl()->getDefinitionOrSelf());
-  } else if (TC1 == Type::Record) {
+  if (TC1 == Type::Enum)
+    return isLayoutCompatible(C, T1->castAsEnumDecl(), T2->castAsEnumDecl());
+  if (TC1 == Type::Record) {
     if (!T1->isStandardLayoutType() || !T2->isStandardLayoutType())
       return false;
 
-    return isLayoutCompatible(
-        C, cast<RecordType>(T1)->getOriginalDecl()->getDefinitionOrSelf(),
-        cast<RecordType>(T2)->getOriginalDecl()->getDefinitionOrSelf());
+    return isLayoutCompatible(C, T1->castAsRecordDecl(),
+                              T2->castAsRecordDecl());
   }
 
   return false;
@@ -15721,9 +15697,7 @@ void Sema::RefersToMemberWithReducedAlignment(
       return;
     if (ME->isArrow())
       BaseType = BaseType->getPointeeType();
-    RecordDecl *RD = BaseType->castAs<RecordType>()
-                         ->getOriginalDecl()
-                         ->getDefinitionOrSelf();
+    auto *RD = BaseType->castAsRecordDecl();
     if (RD->isInvalidDecl())
       return;
 

diff  --git a/clang/lib/Sema/SemaCodeComplete.cpp b/clang/lib/Sema/SemaCodeComplete.cpp
index c6adead5a65f5..03bf4b3690b13 100644
--- a/clang/lib/Sema/SemaCodeComplete.cpp
+++ b/clang/lib/Sema/SemaCodeComplete.cpp
@@ -5113,10 +5113,7 @@ void SemaCodeCompletion::CodeCompleteExpression(
     PreferredTypeIsPointer = Data.PreferredType->isAnyPointerType() ||
                              Data.PreferredType->isMemberPointerType() ||
                              Data.PreferredType->isBlockPointerType();
-    if (Data.PreferredType->isEnumeralType()) {
-      EnumDecl *Enum = Data.PreferredType->castAs<EnumType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf();
+    if (auto *Enum = Data.PreferredType->getAsEnumDecl()) {
       // FIXME: collect covered enumerators in cases like:
       //        if (x == my_enum::one) { ... } else if (x == ^) {}
       AddEnumerators(Results, getASTContext(), Enum, SemaRef.CurContext,
@@ -6240,18 +6237,14 @@ void SemaCodeCompletion::CodeCompleteCase(Scope *S) {
   if (!Switch->getCond())
     return;
   QualType type = Switch->getCond()->IgnoreImplicit()->getType();
-  if (!type->isEnumeralType()) {
+  EnumDecl *Enum = type->getAsEnumDecl();
+  if (!Enum) {
     CodeCompleteExpressionData Data(type);
     Data.IntegralConstantExpression = true;
     CodeCompleteExpression(S, Data);
     return;
   }
 
-  // Code-complete the cases of a switch statement over an enumeration type
-  // by providing the list of
-  EnumDecl *Enum =
-      type->castAs<EnumType>()->getOriginalDecl()->getDefinitionOrSelf();
-
   // Determine which enumerators we have already seen in the switch statement.
   // FIXME: Ideally, we would also be able to look *past* the code-completion
   // token, in case we are code-completing in the middle of the switch and not

diff  --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 8a5b9fc19cd1f..a825fdc1748c6 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -9811,8 +9811,7 @@ static void checkIsValidOpenCLKernelParameter(
 
   // At this point we already handled everything except of a RecordType.
   assert(PT->isRecordType() && "Unexpected type.");
-  const RecordDecl *PD =
-      PT->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *PD = PT->castAsRecordDecl();
   VisitStack.push_back(PD);
   assert(VisitStack.back() && "First decl null?");
 
@@ -9840,9 +9839,7 @@ static void checkIsValidOpenCLKernelParameter(
              "Unexpected type.");
       const Type *FieldRecTy = FieldTy->getPointeeOrArrayElementType();
 
-      RD = FieldRecTy->castAs<RecordType>()
-               ->getOriginalDecl()
-               ->getDefinitionOrSelf();
+      RD = FieldRecTy->castAsRecordDecl();
     } else {
       RD = cast<RecordDecl>(Next);
     }
@@ -13370,8 +13367,7 @@ struct DiagNonTrivalCUnionDefaultInitializeVisitor
   }
 
   void visitStruct(QualType QT, const FieldDecl *FD, bool InNonTrivialUnion) {
-    const RecordDecl *RD =
-        QT->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *RD = QT->castAsRecordDecl();
     if (RD->isUnion()) {
       if (OrigLoc.isValid()) {
         bool IsUnion = false;
@@ -13437,8 +13433,7 @@ struct DiagNonTrivalCUnionDestructedTypeVisitor
   }
 
   void visitStruct(QualType QT, const FieldDecl *FD, bool InNonTrivialUnion) {
-    const RecordDecl *RD =
-        QT->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *RD = QT->castAsRecordDecl();
     if (RD->isUnion()) {
       if (OrigLoc.isValid()) {
         bool IsUnion = false;
@@ -13503,8 +13498,7 @@ struct DiagNonTrivalCUnionCopyVisitor
   }
 
   void visitStruct(QualType QT, const FieldDecl *FD, bool InNonTrivialUnion) {
-    const RecordDecl *RD =
-        QT->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *RD = QT->castAsRecordDecl();
     if (RD->isUnion()) {
       if (OrigLoc.isValid()) {
         bool IsUnion = false;

diff  --git a/clang/lib/Sema/SemaDeclCXX.cpp b/clang/lib/Sema/SemaDeclCXX.cpp
index 34e46fba8250f..5dc0a7eb4c29d 100644
--- a/clang/lib/Sema/SemaDeclCXX.cpp
+++ b/clang/lib/Sema/SemaDeclCXX.cpp
@@ -2187,10 +2187,7 @@ static bool CheckConstexprCtorInitializer(Sema &SemaRef,
       return false;
     }
   } else if (Field->isAnonymousStructOrUnion()) {
-    const RecordDecl *RD = Field->getType()
-                               ->castAs<RecordType>()
-                               ->getOriginalDecl()
-                               ->getDefinitionOrSelf();
+    const auto *RD = Field->getType()->castAsRecordDecl();
     for (auto *I : RD->fields())
       // If an anonymous union contains an anonymous struct of which any member
       // is initialized, all members must be initialized.
@@ -10471,11 +10468,7 @@ struct FindHiddenVirtualMethod {
   /// method overloads virtual methods in a base class without overriding any,
   /// to be used with CXXRecordDecl::lookupInBases().
   bool operator()(const CXXBaseSpecifier *Specifier, CXXBasePath &Path) {
-    RecordDecl *BaseRecord = Specifier->getType()
-                                 ->castAs<RecordType>()
-                                 ->getOriginalDecl()
-                                 ->getDefinitionOrSelf();
-
+    auto *BaseRecord = Specifier->getType()->castAsRecordDecl();
     DeclarationName Name = Method->getDeclName();
     assert(Name.getNameKind() == DeclarationName::Identifier);
 
@@ -12651,15 +12644,12 @@ Decl *Sema::ActOnUsingEnumDeclaration(Scope *S, AccessSpecifier AS,
     return nullptr;
   }
 
-  auto *Enum = dyn_cast_if_present<EnumDecl>(EnumTy->getAsTagDecl());
+  auto *Enum = EnumTy->getAsEnumDecl();
   if (!Enum) {
     Diag(IdentLoc, diag::err_using_enum_not_enum) << EnumTy;
     return nullptr;
   }
 
-  if (auto *Def = Enum->getDefinition())
-    Enum = Def;
-
   if (TSI == nullptr)
     TSI = Context.getTrivialTypeSourceInfo(EnumTy, IdentLoc);
 
@@ -19152,9 +19142,7 @@ void Sema::MarkVirtualMembersReferenced(SourceLocation Loc,
     return;
 
   for (const auto &I : RD->bases()) {
-    const auto *Base = cast<CXXRecordDecl>(
-                           I.getType()->castAs<RecordType>()->getOriginalDecl())
-                           ->getDefinitionOrSelf();
+    const auto *Base = I.getType()->castAsCXXRecordDecl();
     if (Base->getNumVBases() == 0)
       continue;
     MarkVirtualMembersReferenced(Loc, Base);

diff  --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 6e5cc7837ecc1..6797353db14bf 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -12382,10 +12382,7 @@ static QualType checkArithmeticOrEnumeralThreeWayCompare(Sema &S,
       S.InvalidOperands(Loc, LHS, RHS);
       return QualType();
     }
-    QualType IntType = LHSStrippedType->castAs<EnumType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf()
-                           ->getIntegerType();
+    QualType IntType = LHSStrippedType->castAsEnumDecl()->getIntegerType();
     assert(IntType->isArithmeticType());
 
     // We can't use `CK_IntegralCast` when the underlying type is 'bool', so we
@@ -16882,9 +16879,8 @@ ExprResult Sema::BuildVAArgExpr(SourceLocation BuiltinLoc,
       // va_arg. Instead, get the underlying type of the enumeration and pass
       // that.
       QualType UnderlyingType = TInfo->getType();
-      if (const auto *ET = UnderlyingType->getAs<EnumType>())
-        UnderlyingType =
-            ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+      if (const auto *ED = UnderlyingType->getAsEnumDecl())
+        UnderlyingType = ED->getIntegerType();
       if (Context.typesAreCompatible(PromoteType, UnderlyingType,
                                      /*CompareUnqualified*/ true))
         PromoteType = QualType();

diff  --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index b257df438c019..763fc0747eb82 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -3052,9 +3052,7 @@ bool Sema::FindAllocationFunctions(
   LookupResult FoundDelete(*this, DeleteName, StartLoc, LookupOrdinaryName);
   if (AllocElemType->isRecordType() &&
       DeleteScope != AllocationFunctionScope::Global) {
-    auto *RD = cast<CXXRecordDecl>(
-                   AllocElemType->castAs<RecordType>()->getOriginalDecl())
-                   ->getDefinitionOrSelf();
+    auto *RD = AllocElemType->castAsCXXRecordDecl();
     LookupQualifiedName(FoundDelete, RD);
   }
   if (FoundDelete.isAmbiguous())
@@ -4835,10 +4833,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
     if (FromType->isVectorType() || ToType->isVectorType())
       StepTy = adjustVectorType(Context, FromType, ToType, &ElTy);
     if (ElTy->isBooleanType()) {
-      assert(FromType->castAs<EnumType>()
-                 ->getOriginalDecl()
-                 ->getDefinitionOrSelf()
-                 ->isFixed() &&
+      assert(FromType->castAsEnumDecl()->isFixed() &&
              SCS.Second == ICK_Integral_Promotion &&
              "only enums with fixed underlying type can promote to bool");
       From = ImpCastExprToType(From, StepTy, CK_IntegralToBoolean, VK_PRValue,
@@ -7523,12 +7518,10 @@ ExprResult Sema::IgnoredValueConversions(Expr *E) {
   }
 
   // GCC seems to also exclude expressions of incomplete enum type.
-  if (const EnumType *T = E->getType()->getAs<EnumType>()) {
-    if (!T->getOriginalDecl()->getDefinitionOrSelf()->isComplete()) {
-      // FIXME: stupid workaround for a codegen bug!
-      E = ImpCastExprToType(E, Context.VoidTy, CK_ToVoid).get();
-      return E;
-    }
+  if (const auto *ED = E->getType()->getAsEnumDecl(); ED && !ED->isComplete()) {
+    // FIXME: stupid workaround for a codegen bug!
+    E = ImpCastExprToType(E, Context.VoidTy, CK_ToVoid).get();
+    return E;
   }
 
   ExprResult Res = DefaultFunctionArrayLvalueConversion(E);

diff  --git a/clang/lib/Sema/SemaExprObjC.cpp b/clang/lib/Sema/SemaExprObjC.cpp
index 03b5c79cf70e3..5684e53a6bcaa 100644
--- a/clang/lib/Sema/SemaExprObjC.cpp
+++ b/clang/lib/Sema/SemaExprObjC.cpp
@@ -638,8 +638,7 @@ ExprResult SemaObjC::BuildObjCBoxedExpr(SourceRange SR, Expr *ValueExpr) {
     // Look for the appropriate method within NSNumber.
     BoxingMethod = getNSNumberFactoryMethod(*this, Loc, ValueType);
     BoxedType = NSNumberPointer;
-  } else if (const EnumType *ET = ValueType->getAs<EnumType>()) {
-    const EnumDecl *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+  } else if (const auto *ED = ValueType->getAsEnumDecl()) {
     if (!ED->isComplete()) {
       Diag(Loc, diag::err_objc_incomplete_boxed_expression_type)
         << ValueType << ValueExpr->getSourceRange();

diff  --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index b89aa9793c98c..6a68fa2ed7a8b 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -386,14 +386,14 @@ static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) {
     QualType Ty = Field->getType();
     if (isInvalidConstantBufferLeafElementType(Ty.getTypePtr()))
       return true;
-    if (Ty->isRecordType() &&
-        requiresImplicitBufferLayoutStructure(Ty->getAsCXXRecordDecl()))
+    if (const auto *RD = Ty->getAsCXXRecordDecl();
+        RD && requiresImplicitBufferLayoutStructure(RD))
       return true;
   }
   // check bases
   for (const CXXBaseSpecifier &Base : RD->bases())
     if (requiresImplicitBufferLayoutStructure(
-            Base.getType()->getAsCXXRecordDecl()))
+            Base.getType()->castAsCXXRecordDecl()))
       return true;
   return false;
 }
@@ -509,7 +509,7 @@ static CXXRecordDecl *createHostLayoutStruct(Sema &S,
     assert(NumBases == 1 && "HLSL supports only one base type");
     (void)NumBases;
     CXXBaseSpecifier Base = *StructDecl->bases_begin();
-    CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl();
+    CXXRecordDecl *BaseDecl = Base.getType()->castAsCXXRecordDecl();
     if (requiresImplicitBufferLayoutStructure(BaseDecl)) {
       BaseDecl = createHostLayoutStruct(S, BaseDecl);
       if (BaseDecl) {
@@ -3979,19 +3979,19 @@ class InitListTransformer {
       return true;
     }
 
-    if (auto *RTy = Ty->getAs<RecordType>()) {
-      llvm::SmallVector<const RecordType *> RecordTypes;
-      RecordTypes.push_back(RTy);
-      while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
-        CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
+    if (auto *RD = Ty->getAsCXXRecordDecl()) {
+      llvm::SmallVector<CXXRecordDecl *> RecordDecls;
+      RecordDecls.push_back(RD);
+      while (RecordDecls.back()->getNumBases()) {
+        CXXRecordDecl *D = RecordDecls.back();
         assert(D->getNumBases() == 1 &&
                "HLSL doesn't support multiple inheritance");
-        RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>());
+        RecordDecls.push_back(
+            D->bases_begin()->getType()->castAsCXXRecordDecl());
       }
-      while (!RecordTypes.empty()) {
-        const RecordType *RT = RecordTypes.pop_back_val();
-        for (auto *FD :
-             RT->getOriginalDecl()->getDefinitionOrSelf()->fields()) {
+      while (!RecordDecls.empty()) {
+        CXXRecordDecl *RD = RecordDecls.pop_back_val();
+        for (auto *FD : RD->fields()) {
           DeclAccessPair Found = DeclAccessPair::make(FD, FD->getAccess());
           DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
           ExprResult Res = S.BuildFieldReferenceExpr(
@@ -4028,21 +4028,20 @@ class InitListTransformer {
       for (uint64_t I = 0; I < Size; ++I)
         Inits.push_back(generateInitListsImpl(ElTy));
     }
-    if (auto *RTy = Ty->getAs<RecordType>()) {
-      llvm::SmallVector<const RecordType *> RecordTypes;
-      RecordTypes.push_back(RTy);
-      while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
-        CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
+    if (auto *RD = Ty->getAsCXXRecordDecl()) {
+      llvm::SmallVector<CXXRecordDecl *> RecordDecls;
+      RecordDecls.push_back(RD);
+      while (RecordDecls.back()->getNumBases()) {
+        CXXRecordDecl *D = RecordDecls.back();
         assert(D->getNumBases() == 1 &&
                "HLSL doesn't support multiple inheritance");
-        RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>());
+        RecordDecls.push_back(
+            D->bases_begin()->getType()->castAsCXXRecordDecl());
       }
-      while (!RecordTypes.empty()) {
-        const RecordType *RT = RecordTypes.pop_back_val();
-        for (auto *FD :
-             RT->getOriginalDecl()->getDefinitionOrSelf()->fields()) {
+      while (!RecordDecls.empty()) {
+        CXXRecordDecl *RD = RecordDecls.pop_back_val();
+        for (auto *FD : RD->fields())
           Inits.push_back(generateInitListsImpl(FD->getType()));
-        }
       }
     }
     auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),

diff  --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index 0242671aecdb3..4dfeebdea7919 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -1125,8 +1125,7 @@ int InitListChecker::numArrayElements(QualType DeclType) {
 }
 
 int InitListChecker::numStructUnionElements(QualType DeclType) {
-  RecordDecl *structDecl =
-      DeclType->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+  auto *structDecl = DeclType->castAsRecordDecl();
   int InitializableMembers = 0;
   if (auto *CXXRD = dyn_cast<CXXRecordDecl>(structDecl))
     InitializableMembers += CXXRD->getNumBases();
@@ -1155,22 +1154,14 @@ static bool isIdiomaticBraceElisionEntity(const InitializedEntity &Entity) {
 
   // Allows elide brace initialization for aggregates with empty base.
   if (Entity.getKind() == InitializedEntity::EK_Base) {
-    auto *ParentRD = Entity.getParent()
-                         ->getType()
-                         ->castAs<RecordType>()
-                         ->getOriginalDecl()
-                         ->getDefinitionOrSelf();
+    auto *ParentRD = Entity.getParent()->getType()->castAsRecordDecl();
     CXXRecordDecl *CXXRD = cast<CXXRecordDecl>(ParentRD);
     return CXXRD->getNumBases() == 1 && CXXRD->field_empty();
   }
 
   // Allow brace elision if the only subobject is a field.
   if (Entity.getKind() == InitializedEntity::EK_Member) {
-    auto *ParentRD = Entity.getParent()
-                         ->getType()
-                         ->castAs<RecordType>()
-                         ->getOriginalDecl()
-                         ->getDefinitionOrSelf();
+    auto *ParentRD = Entity.getParent()->getType()->castAsRecordDecl();
     if (CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(ParentRD)) {
       if (CXXRD->getNumBases()) {
         return false;
@@ -2347,7 +2338,9 @@ void InitListChecker::CheckStructUnionTypes(
            Field != FieldEnd; ++Field) {
         if (Field->hasInClassInitializer() ||
             (Field->isAnonymousStructOrUnion() &&
-             Field->getType()->getAsCXXRecordDecl()->hasInClassInitializer())) {
+             Field->getType()
+                 ->castAsCXXRecordDecl()
+                 ->hasInClassInitializer())) {
           StructuredList->setInitializedFieldInUnion(*Field);
           // FIXME: Actually build a CXXDefaultInitExpr?
           return;
@@ -4534,11 +4527,7 @@ static void TryConstructorInitialization(Sema &S,
     }
   }
 
-  const RecordType *DestRecordType = DestType->getAs<RecordType>();
-  assert(DestRecordType && "Constructor initialization requires record type");
-  auto *DestRecordDecl = cast<CXXRecordDecl>(DestRecordType->getOriginalDecl())
-                             ->getDefinitionOrSelf();
-
+  auto *DestRecordDecl = DestType->castAsCXXRecordDecl();
   // Build the candidate set directly in the initialization sequence
   // structure, so that it will persist if we fail.
   OverloadCandidateSet &CandidateSet = Sequence.getFailedCandidateSet();
@@ -5025,7 +5014,7 @@ static void TryListInitialization(Sema &S,
       //     class type with a default constructor, the object is
       //     value-initialized.
       if (InitList->getNumInits() == 0) {
-        CXXRecordDecl *RD = DestType->getAsCXXRecordDecl();
+        CXXRecordDecl *RD = DestType->castAsCXXRecordDecl();
         if (S.LookupDefaultConstructor(RD)) {
           TryValueInitialization(S, Entity, Kind, Sequence, InitList);
           return;
@@ -5056,10 +5045,9 @@ static void TryListInitialization(Sema &S,
     //     is direct-list-initialization, the object is initialized with the
     //     value T(v); if a narrowing conversion is required to convert v to
     //     the underlying type of T, the program is ill-formed.
-    auto *ET = DestType->getAs<EnumType>();
     if (S.getLangOpts().CPlusPlus17 &&
-        Kind.getKind() == InitializationKind::IK_DirectList && ET &&
-        ET->getOriginalDecl()->getDefinitionOrSelf()->isFixed() &&
+        Kind.getKind() == InitializationKind::IK_DirectList &&
+        DestType->isEnumeralType() && DestType->castAsEnumDecl()->isFixed() &&
         !S.Context.hasSameUnqualifiedType(E->getType(), DestType) &&
         (E->getType()->isIntegralOrUnscopedEnumerationType() ||
          E->getType()->isFloatingType())) {
@@ -5164,14 +5152,13 @@ static OverloadingResult TryRefInitWithConversionFunction(
   bool AllowExplicitCtors = false;
   bool AllowExplicitConvs = Kind.allowExplicitConversionFunctionsInRefBinding();
 
-  const RecordType *T1RecordType = nullptr;
-  if (AllowRValues && (T1RecordType = T1->getAs<RecordType>()) &&
+  if (AllowRValues && T1->isRecordType() &&
       S.isCompleteType(Kind.getLocation(), T1)) {
+    auto *T1RecordDecl = T1->castAsCXXRecordDecl();
+    if (T1RecordDecl->isInvalidDecl())
+      return OR_No_Viable_Function;
     // The type we're converting to is a class type. Enumerate its constructors
     // to see if there is a suitable conversion.
-    auto *T1RecordDecl = cast<CXXRecordDecl>(T1RecordType->getOriginalDecl())
-                             ->getDefinitionOrSelf();
-
     for (NamedDecl *D : S.LookupConstructors(T1RecordDecl)) {
       auto Info = getConstructorInfo(D);
       if (!Info.Constructor)
@@ -5193,18 +5180,13 @@ static OverloadingResult TryRefInitWithConversionFunction(
       }
     }
   }
-  if (T1RecordType &&
-      T1RecordType->getOriginalDecl()->getDefinitionOrSelf()->isInvalidDecl())
-    return OR_No_Viable_Function;
 
-  const RecordType *T2RecordType = nullptr;
-  if ((T2RecordType = T2->getAs<RecordType>()) &&
-      S.isCompleteType(Kind.getLocation(), T2)) {
+  if (T2->isRecordType() && S.isCompleteType(Kind.getLocation(), T2)) {
+    const auto *T2RecordDecl = T2->castAsCXXRecordDecl();
+    if (T2RecordDecl->isInvalidDecl())
+      return OR_No_Viable_Function;
     // The type we're converting from is a class type, enumerate its conversion
     // functions.
-    auto *T2RecordDecl = cast<CXXRecordDecl>(T2RecordType->getOriginalDecl())
-                             ->getDefinitionOrSelf();
-
     const auto &Conversions = T2RecordDecl->getVisibleConversionFunctions();
     for (auto I = Conversions.begin(), E = Conversions.end(); I != E; ++I) {
       NamedDecl *D = *I;
@@ -5239,9 +5221,6 @@ static OverloadingResult TryRefInitWithConversionFunction(
       }
     }
   }
-  if (T2RecordType &&
-      T2RecordType->getOriginalDecl()->getDefinitionOrSelf()->isInvalidDecl())
-    return OR_No_Viable_Function;
 
   SourceLocation DeclLoc = Initializer->getBeginLoc();
 
@@ -6099,15 +6078,12 @@ static void TryUserDefinedConversion(Sema &S,
   // explicit conversion operators.
   bool AllowExplicit = Kind.AllowExplicit();
 
-  if (const RecordType *DestRecordType = DestType->getAs<RecordType>()) {
+  if (DestType->isRecordType()) {
     // The type we're converting to is a class type. Enumerate its constructors
     // to see if there is a suitable conversion.
-    auto *DestRecordDecl =
-        cast<CXXRecordDecl>(DestRecordType->getOriginalDecl())
-            ->getDefinitionOrSelf();
-
     // Try to complete the type we're converting to.
     if (S.isCompleteType(Kind.getLocation(), DestType)) {
+      auto *DestRecordDecl = DestType->castAsCXXRecordDecl();
       for (NamedDecl *D : S.LookupConstructors(DestRecordDecl)) {
         auto Info = getConstructorInfo(D);
         if (!Info.Constructor)
@@ -6133,17 +6109,14 @@ static void TryUserDefinedConversion(Sema &S,
 
   SourceLocation DeclLoc = Initializer->getBeginLoc();
 
-  if (const RecordType *SourceRecordType = SourceType->getAs<RecordType>()) {
+  if (SourceType->isRecordType()) {
     // The type we're converting from is a class type, enumerate its conversion
     // functions.
 
     // We can only enumerate the conversion functions for a complete type; if
     // the type isn't complete, simply skip this step.
     if (S.isCompleteType(DeclLoc, SourceType)) {
-      auto *SourceRecordDecl =
-          cast<CXXRecordDecl>(SourceRecordType->getOriginalDecl())
-              ->getDefinitionOrSelf();
-
+      auto *SourceRecordDecl = SourceType->castAsCXXRecordDecl();
       const auto &Conversions =
           SourceRecordDecl->getVisibleConversionFunctions();
       for (auto I = Conversions.begin(), E = Conversions.end(); I != E; ++I) {
@@ -8158,10 +8131,8 @@ ExprResult InitializationSequence::Perform(Sema &S,
         // FIXME: It makes no sense to do this here. This should happen
         // regardless of how we initialized the entity.
         QualType T = CurInit.get()->getType();
-        if (const RecordType *Record = T->getAs<RecordType>()) {
-          CXXDestructorDecl *Destructor =
-              S.LookupDestructor(cast<CXXRecordDecl>(Record->getOriginalDecl())
-                                     ->getDefinitionOrSelf());
+        if (auto *Record = T->castAsCXXRecordDecl()) {
+          CXXDestructorDecl *Destructor = S.LookupDestructor(Record);
           S.CheckDestructorAccess(CurInit.get()->getBeginLoc(), Destructor,
                                   S.PDiag(diag::err_access_dtor_temp) << T);
           S.MarkFunctionReferenced(CurInit.get()->getBeginLoc(), Destructor);
@@ -8549,9 +8520,8 @@ ExprResult InitializationSequence::Perform(Sema &S,
             S.isStdInitializerList(Step->Type, &ElementType);
         assert(IsStdInitializerList &&
                "StdInitializerList step to non-std::initializer_list");
-        const CXXRecordDecl *Record =
-            Step->Type->getAsCXXRecordDecl()->getDefinition();
-        assert(Record && Record->isCompleteDefinition() &&
+        const auto *Record = Step->Type->castAsCXXRecordDecl();
+        assert(Record->isCompleteDefinition() &&
                "std::initializer_list should have already be "
                "complete/instantiated by this point");
 
@@ -9214,11 +9184,8 @@ bool InitializationSequence::Diagnose(Sema &S,
                 << S.Context.getCanonicalTagType(Constructor->getParent())
                 << /*base=*/0 << Entity.getType() << InheritedFrom;
 
-            RecordDecl *BaseDecl = Entity.getBaseSpecifier()
-                                       ->getType()
-                                       ->castAs<RecordType>()
-                                       ->getOriginalDecl()
-                                       ->getDefinitionOrSelf();
+            auto *BaseDecl =
+                Entity.getBaseSpecifier()->getType()->castAsRecordDecl();
             S.Diag(BaseDecl->getLocation(), diag::note_previous_decl)
                 << S.Context.getCanonicalTagType(BaseDecl);
           } else {

diff  --git a/clang/lib/Sema/SemaLambda.cpp b/clang/lib/Sema/SemaLambda.cpp
index 8a81cbe2623d8..fbc2e7eb30676 100644
--- a/clang/lib/Sema/SemaLambda.cpp
+++ b/clang/lib/Sema/SemaLambda.cpp
@@ -641,9 +641,8 @@ static EnumDecl *findEnumForBlockReturn(Expr *E) {
   }
 
   //   - it is an expression of that formal enum type.
-  if (const EnumType *ET = E->getType()->getAs<EnumType>()) {
-    return ET->getOriginalDecl()->getDefinitionOrSelf();
-  }
+  if (auto *ED = E->getType()->getAsEnumDecl())
+    return ED;
 
   // Otherwise, nope.
   return nullptr;

diff  --git a/clang/lib/Sema/SemaLookup.cpp b/clang/lib/Sema/SemaLookup.cpp
index 42519a52b57ac..86ffae9363beb 100644
--- a/clang/lib/Sema/SemaLookup.cpp
+++ b/clang/lib/Sema/SemaLookup.cpp
@@ -2770,10 +2770,7 @@ bool Sema::LookupInSuper(LookupResult &R, CXXRecordDecl *Class) {
   // members of Class itself.  That is, the naming class is Class, and the
   // access includes the access of the base.
   for (const auto &BaseSpec : Class->bases()) {
-    CXXRecordDecl *RD =
-        cast<CXXRecordDecl>(
-            BaseSpec.getType()->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
+    auto *RD = BaseSpec.getType()->castAsCXXRecordDecl();
     LookupResult Result(*this, R.getLookupNameInfo(), R.getLookupKind());
     Result.setBaseObjectType(Context.getCanonicalTagType(Class));
     LookupQualifiedName(Result, RD);
@@ -3207,8 +3204,7 @@ addAssociatedClassesAndNamespaces(AssociatedLookup &Result, QualType Ty) {
     //        member’s class; else it has no associated class.
     case Type::Enum: {
       // FIXME: This should use the original decl.
-      EnumDecl *Enum =
-          cast<EnumType>(T)->getOriginalDecl()->getDefinitionOrSelf();
+      auto *Enum = T->castAsEnumDecl();
 
       DeclContext *Ctx = Enum->getDeclContext();
       if (CXXRecordDecl *EnclosingClass = dyn_cast<CXXRecordDecl>(Ctx))

diff  --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 7d800c446b595..8f666cee72937 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -18628,12 +18628,12 @@ buildDeclareReductionRef(Sema &SemaRef, SourceLocation Loc, SourceRange Range,
   //        the set of member candidates is empty.
   LookupResult Lookup(SemaRef, ReductionId, Sema::LookupOMPReductionName);
   Lookup.suppressDiagnostics();
-  if (const auto *TyRec = Ty->getAs<RecordType>()) {
+  if (Ty->isRecordType()) {
     // Complete the type if it can be completed.
     // If the type is neither complete nor being defined, bail out now.
     bool IsComplete = SemaRef.isCompleteType(Loc, Ty);
-    RecordDecl *RD = TyRec->getOriginalDecl()->getDefinition();
-    if (IsComplete || RD) {
+    auto *RD = Ty->castAsRecordDecl();
+    if (IsComplete || RD->isBeingDefined()) {
       Lookup.clear();
       SemaRef.LookupQualifiedName(Lookup, RD);
       if (Lookup.empty()) {

diff  --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 5678804d87f1f..d24550060893b 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -369,8 +369,8 @@ NarrowingKind StandardConversionSequence::getNarrowingKind(
   // A conversion to an enumeration type is narrowing if the conversion to
   // the underlying type is narrowing. This only arises for expressions of
   // the form 'Enum{init}'.
-  if (auto *ET = ToType->getAs<EnumType>())
-    ToType = ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = ToType->getAsEnumDecl())
+    ToType = ED->getIntegerType();
 
   switch (Second) {
   // 'bool' is an integral type; dispatch to the right place to handle it.
@@ -2661,11 +2661,9 @@ bool Sema::IsIntegralPromotion(Expr *From, QualType FromType, QualType ToType) {
   //   integral promotion can be applied to its underlying type, a prvalue of an
   //   unscoped enumeration type whose underlying type is fixed can also be
   //   converted to a prvalue of the promoted underlying type.
-  if (const EnumType *FromEnumType = FromType->getAs<EnumType>()) {
+  if (const auto *FromED = FromType->getAsEnumDecl()) {
     // C++0x 7.2p9: Note that this implicit enum to int conversion is not
     // provided for a scoped enumeration.
-    const EnumDecl *FromED =
-        FromEnumType->getOriginalDecl()->getDefinitionOrSelf();
     if (FromED->isScoped())
       return false;
 
@@ -4508,12 +4506,10 @@ getFixedEnumPromtion(Sema &S, const StandardConversionSequence &SCS) {
   if (SCS.Second != ICK_Integral_Promotion)
     return FixedEnumPromotion::None;
 
-  QualType FromType = SCS.getFromType();
-  if (!FromType->isEnumeralType())
+  const auto *Enum = SCS.getFromType()->getAsEnumDecl();
+  if (!Enum)
     return FixedEnumPromotion::None;
 
-  EnumDecl *Enum =
-      FromType->castAs<EnumType>()->getOriginalDecl()->getDefinitionOrSelf();
   if (!Enum->isFixed())
     return FixedEnumPromotion::None;
 
@@ -5149,10 +5145,7 @@ FindConversionForRefInit(Sema &S, ImplicitConversionSequence &ICS,
                          Expr *Init, QualType T2, bool AllowRvalues,
                          bool AllowExplicit) {
   assert(T2->isRecordType() && "Can only find conversions of record types.");
-  auto *T2RecordDecl =
-      cast<CXXRecordDecl>(T2->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
-
+  auto *T2RecordDecl = T2->castAsCXXRecordDecl();
   OverloadCandidateSet CandidateSet(
       DeclLoc, OverloadCandidateSet::CSK_InitByUserDefinedConversion);
   const auto &Conversions = T2RecordDecl->getVisibleConversionFunctions();
@@ -8340,10 +8333,7 @@ void Sema::AddConversionCandidate(
   QualType ObjectType = From->getType();
   if (const auto *FromPtrType = ObjectType->getAs<PointerType>())
     ObjectType = FromPtrType->getPointeeType();
-  const auto *ConversionContext =
-      cast<CXXRecordDecl>(ObjectType->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
-
+  const auto *ConversionContext = ObjectType->castAsCXXRecordDecl();
   // C++23 [over.best.ics.general]
   // However, if the target is [...]
   // - the object parameter of a user-defined conversion function
@@ -9093,9 +9083,7 @@ BuiltinCandidateTypeSet::AddTypesConvertedFrom(QualType Ty,
     if (!SemaRef.isCompleteType(Loc, Ty))
       return;
 
-    auto *ClassDecl =
-        cast<CXXRecordDecl>(Ty->getAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
+    auto *ClassDecl = Ty->castAsCXXRecordDecl();
     for (NamedDecl *D : ClassDecl->getVisibleConversionFunctions()) {
       if (isa<UsingShadowDecl>(D))
         D = cast<UsingShadowDecl>(D)->getTargetDecl();
@@ -10154,7 +10142,7 @@ class BuiltinOperatorOverloadBuilder {
         continue;
       for (QualType MemPtrTy : CandidateTypes[1].member_pointer_types()) {
         const MemberPointerType *mptr = cast<MemberPointerType>(MemPtrTy);
-        CXXRecordDecl *D1 = C1->getAsCXXRecordDecl(),
+        CXXRecordDecl *D1 = C1->castAsCXXRecordDecl(),
                       *D2 = mptr->getMostRecentCXXRecordDecl();
         if (!declaresSameEntity(D1, D2) &&
             !S.IsDerivedFrom(CandidateSet.getLocation(), D1, D2))
@@ -16352,9 +16340,9 @@ Sema::BuildCallToObjectOfClassType(Scope *S, Expr *Obj,
                           diag::err_incomplete_object_call, Object.get()))
     return true;
 
-  const auto *Record = Object.get()->getType()->castAs<RecordType>();
+  auto *Record = Object.get()->getType()->castAsCXXRecordDecl();
   LookupResult R(*this, OpName, LParenLoc, LookupOrdinaryName);
-  LookupQualifiedName(R, Record->getOriginalDecl()->getDefinitionOrSelf());
+  LookupQualifiedName(R, Record);
   R.suppressAccessDiagnostics();
 
   for (LookupResult::iterator Oper = R.begin(), OperEnd = R.end();
@@ -16373,8 +16361,7 @@ Sema::BuildCallToObjectOfClassType(Scope *S, Expr *Obj,
   // we filter them out to produce better error diagnostics, ie to avoid
   // showing 2 failed overloads instead of one.
   bool IgnoreSurrogateFunctions = false;
-  if (CandidateSet.nonDeferredCandidatesCount() == 1 &&
-      Record->getAsCXXRecordDecl()->isLambda()) {
+  if (CandidateSet.nonDeferredCandidatesCount() == 1 && Record->isLambda()) {
     const OverloadCandidate &Candidate = *CandidateSet.begin();
     if (!Candidate.Viable &&
         Candidate.FailureKind == ovl_fail_constraints_not_satisfied)
@@ -16398,9 +16385,7 @@ Sema::BuildCallToObjectOfClassType(Scope *S, Expr *Obj,
   //   functions for each conversion function declared in an
   //   accessible base class provided the function is not hidden
   //   within T by another intervening declaration.
-  const auto &Conversions = cast<CXXRecordDecl>(Record->getOriginalDecl())
-                                ->getDefinitionOrSelf()
-                                ->getVisibleConversionFunctions();
+  const auto &Conversions = Record->getVisibleConversionFunctions();
   for (auto I = Conversions.begin(), E = Conversions.end();
        !IgnoreSurrogateFunctions && I != E; ++I) {
     NamedDecl *D = *I;
@@ -16621,10 +16606,7 @@ ExprResult Sema::BuildOverloadedArrowExpr(Scope *S, Expr *Base,
     return ExprError();
 
   LookupResult R(*this, OpName, OpLoc, LookupOrdinaryName);
-  LookupQualifiedName(R, Base->getType()
-                             ->castAs<RecordType>()
-                             ->getOriginalDecl()
-                             ->getDefinitionOrSelf());
+  LookupQualifiedName(R, Base->getType()->castAsRecordDecl());
   R.suppressAccessDiagnostics();
 
   for (LookupResult::iterator Oper = R.begin(), OperEnd = R.end();

diff  --git a/clang/lib/Sema/SemaPPC.cpp b/clang/lib/Sema/SemaPPC.cpp
index 46d7372dd056b..bfa458d207b46 100644
--- a/clang/lib/Sema/SemaPPC.cpp
+++ b/clang/lib/Sema/SemaPPC.cpp
@@ -41,10 +41,7 @@ void SemaPPC::checkAIXMemberAlignment(SourceLocation Loc, const Expr *Arg) {
     return;
 
   QualType ArgType = Arg->getType();
-  for (const FieldDecl *FD : ArgType->castAs<RecordType>()
-                                 ->getOriginalDecl()
-                                 ->getDefinitionOrSelf()
-                                 ->fields()) {
+  for (const FieldDecl *FD : ArgType->castAsRecordDecl()->fields()) {
     if (const auto *AA = FD->getAttr<AlignedAttr>()) {
       CharUnits Alignment = getASTContext().toCharUnitsFromBits(
           AA->getAlignment(getASTContext()));

diff  --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp
index bc1ddb80961a2..dda4e00119212 100644
--- a/clang/lib/Sema/SemaStmt.cpp
+++ b/clang/lib/Sema/SemaStmt.cpp
@@ -1590,12 +1590,10 @@ Sema::ActOnFinishSwitchStmt(SourceLocation SwitchLoc, Stmt *Switch,
     // we still do the analysis to preserve this information in the AST
     // (which can be used by flow-based analyes).
     //
-    const EnumType *ET = CondTypeBeforePromotion->getAs<EnumType>();
-
     // If switch has default case, then ignore it.
     if (!CaseListIsErroneous && !CaseListIsIncomplete && !HasConstantCond &&
-        ET) {
-      const EnumDecl *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+        CondTypeBeforePromotion->isEnumeralType()) {
+      const auto *ED = CondTypeBeforePromotion->castAsEnumDecl();
       if (!ED->isCompleteDefinition() || ED->enumerators().empty())
         goto enum_out;
 
@@ -1730,8 +1728,7 @@ void
 Sema::DiagnoseAssignmentEnum(QualType DstType, QualType SrcType,
                              Expr *SrcExpr) {
 
-  const auto *ET = DstType->getAs<EnumType>();
-  if (!ET)
+  if (!DstType->isEnumeralType())
     return;
 
   if (!SrcType->isIntegerType() ||
@@ -1741,7 +1738,7 @@ Sema::DiagnoseAssignmentEnum(QualType DstType, QualType SrcType,
   if (SrcExpr->isTypeDependent() || SrcExpr->isValueDependent())
     return;
 
-  const EnumDecl *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *ED = DstType->castAsEnumDecl();
   if (!ED->isClosed())
     return;
 

diff  --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index ff273ab6fd2f5..e7d981a3210d1 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -7411,9 +7411,8 @@ ExprResult Sema::CheckTemplateArgument(NamedDecl *Param, QualType ParamType,
       // always a no-op, except when the parameter type is bool. In
       // that case, this may extend the argument from 1 bit to 8 bits.
       QualType IntegerType = ParamType;
-      if (const EnumType *Enum = IntegerType->getAs<EnumType>())
-        IntegerType =
-            Enum->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+      if (const auto *ED = IntegerType->getAsEnumDecl())
+        IntegerType = ED->getIntegerType();
       Value = Value.extOrTrunc(IntegerType->isBitIntType()
                                    ? Context.getIntWidth(IntegerType)
                                    : Context.getTypeSize(IntegerType));
@@ -7510,9 +7509,8 @@ ExprResult Sema::CheckTemplateArgument(NamedDecl *Param, QualType ParamType,
     }
 
     QualType IntegerType = ParamType;
-    if (const EnumType *Enum = IntegerType->getAs<EnumType>()) {
-      IntegerType =
-          Enum->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = IntegerType->getAsEnumDecl()) {
+      IntegerType = ED->getIntegerType();
     }
 
     if (ParamType->isBooleanType()) {
@@ -8028,8 +8026,8 @@ static Expr *BuildExpressionFromIntegralTemplateArgumentValue(
   // any integral type with C++11 enum classes, make sure we create the right
   // type of literal for it.
   QualType T = OrigT;
-  if (const EnumType *ET = OrigT->getAs<EnumType>())
-    T = ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = OrigT->getAsEnumDecl())
+    T = ED->getIntegerType();
 
   Expr *E;
   if (T->isAnyCharacterType()) {

diff  --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 015fb0577df39..3f31a05d382a6 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -9617,19 +9617,16 @@ bool Sema::RequireLiteralType(SourceLocation Loc, QualType T,
   if (T->isVariableArrayType())
     return true;
 
-  const RecordType *RT = ElemType->getAs<RecordType>();
-  if (!RT)
+  if (!ElemType->isRecordType())
     return true;
 
-  const CXXRecordDecl *RD =
-      cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
-
   // A partially-defined class type can't be a literal type, because a literal
   // class type must have a trivial destructor (which can't be checked until
   // the class definition is complete).
   if (RequireCompleteType(Loc, ElemType, diag::note_non_literal_incomplete, T))
     return true;
 
+  const auto *RD = ElemType->castAsCXXRecordDecl();
   // [expr.prim.lambda]p3:
   //   This class type is [not] a literal type.
   if (RD->isLambda() && !getLangOpts().CPlusPlus17) {

diff  --git a/clang/lib/StaticAnalyzer/Checkers/EnumCastOutOfRangeChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/EnumCastOutOfRangeChecker.cpp
index 054b2e96bd13b..76a1470aaac44 100644
--- a/clang/lib/StaticAnalyzer/Checkers/EnumCastOutOfRangeChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/EnumCastOutOfRangeChecker.cpp
@@ -139,18 +139,11 @@ void EnumCastOutOfRangeChecker::checkPreStmt(const CastExpr *CE,
   if (!ValueToCast)
     return;
 
-  const QualType T = CE->getType();
   // Check whether the cast type is an enum.
-  if (!T->isEnumeralType())
+  const auto *ED = CE->getType()->getAsEnumDecl();
+  if (!ED)
     return;
 
-  // If the cast is an enum, get its declaration.
-  // If the isEnumeralType() returned true, then the declaration must exist
-  // even if it is a stub declaration. It is up to the getDeclValuesForEnum()
-  // function to handle this.
-  const EnumDecl *ED =
-      T->castAs<EnumType>()->getOriginalDecl()->getDefinitionOrSelf();
-
   // [[clang::flag_enum]] annotated enums are by definition should be ignored.
   if (ED->hasAttr<FlagEnumAttr>())
     return;

diff  --git a/clang/lib/StaticAnalyzer/Core/RegionStore.cpp b/clang/lib/StaticAnalyzer/Core/RegionStore.cpp
index 02375b0c3469a..6a82aa7ba3936 100644
--- a/clang/lib/StaticAnalyzer/Core/RegionStore.cpp
+++ b/clang/lib/StaticAnalyzer/Core/RegionStore.cpp
@@ -2844,9 +2844,7 @@ RegionStoreManager::bindStruct(LimitedRegionBindingsConstRef B,
   QualType T = R->getValueType();
   assert(T->isStructureOrClassType());
 
-  const RecordType* RT = T->castAs<RecordType>();
-  const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
-
+  const auto *RD = T->castAsRecordDecl();
   if (!RD->isCompleteDefinition())
     return B;
 


        


More information about the cfe-commits mailing list