[clang] [clang-tools-extra] [clang] AST: fix getAs canonicalization of leaf types (PR #155028)

Matheus Izvekov via cfe-commits cfe-commits at lists.llvm.org
Fri Aug 22 14:18:28 PDT 2025


https://github.com/mizvekov updated https://github.com/llvm/llvm-project/pull/155028

>From e8a2223a0e8ba6cfc6907b2f525bcde25951d30d Mon Sep 17 00:00:00 2001
From: Matheus Izvekov <mizvekov at gmail.com>
Date: Wed, 20 Aug 2025 00:23:48 -0300
Subject: [PATCH] [clang] AST: fix getAs canonicalization of leaf types

Before this patch, applying getAs or castAs to a leaf type always produced
a canonical type, which is undesirable because some of these types can be
sugared.

Users expect getAs to remove only top-level sugar nodes, leaving all the type
sugar on the returned node. However, it had an optimization intended for type
nodes with no sugar: for these, the expensive traversal of the top-level sugar
can be skipped in favor of a simple canonicalization followed by dyn_cast.

The problem is that the set of leaf types does not coincide with the set of
types for which this optimization is correct.

This patch replaces the concept of leaf types with 'always canonical' types,
and only applies the canonicalization strategy on them.
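
The difference between the two lookup strategies can be illustrated with a
small toy model. This is only a sketch: TypeNode, Kind and Example are
hypothetical stand-ins invented for this illustration, not clang's Type
classes, and the free functions getAs and getAsCanonical here only mimic the
behaviour described above.

```cpp
#include <cassert>

// Toy model only: these are NOT clang's Type classes.
enum class Kind { Record, Typedef };

struct TypeNode {
  Kind kind;
  const char *spelling;      // how the type was written in source
  const TypeNode *desugared; // one level of sugar removed; null if canonical
  const TypeNode *canonical; // cached, fully desugared node
};

// Sugar-preserving lookup: peel sugar one node at a time and stop at the
// first node of the requested kind, keeping whatever sugar that node carries.
inline const TypeNode *getAs(const TypeNode *t, Kind k) {
  for (const TypeNode *cur = t; cur; cur = cur->desugared)
    if (cur->kind == k)
      return cur;
  return nullptr;
}

// Canonical lookup: a single check on the cached canonical node. Cheaper,
// but any sugar on the result is lost.
inline const TypeNode *getAsCanonical(const TypeNode *t, Kind k) {
  assert(t->canonical && "canonical pointer must be set");
  return t->canonical->kind == k ? t->canonical : nullptr;
}

// Alias -> "struct ns::S" (a sugared record) -> S (the canonical record).
struct Example {
  TypeNode canon{Kind::Record, "S", nullptr, nullptr};
  TypeNode sugared{Kind::Record, "struct ns::S", &canon, &canon};
  TypeNode alias{Kind::Typedef, "Alias", &sugared, &canon};
  Example() { canon.canonical = &canon; }
  // The nodes point at each other, so copying would leave dangling links.
  Example(const Example &) = delete;
  Example &operator=(const Example &) = delete;
};
```

Given `Example e;`, the call `getAs(&e.alias, Kind::Record)` stops at
`e.sugared` and so keeps the `struct ns::S` spelling, while
`getAsCanonical(&e.alias, Kind::Record)` returns `e.canon`. In this toy model,
the old leaf-type fast path corresponds to always taking the second route.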

On its own, this change would cause a performance regression. Since most
current users do not care about type sugar, this patch also replaces those uses
with alternative cast functions which operate through canonicalization.

* Introduces castAs variants to complement the getAsTagDecl and derived
  variants.
* Introduces getAsEnumDecl and castAsEnumDecl, complementing the current set, so
  that all TagDecls are covered.
* Introduces getAsCanonical and castAsCanonical, for faster casting when only
  the canonical type is desired.
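
As a rough sketch of how a getAs-style query differs from a castAs-style
checked access, assuming the usual LLVM dyn_cast/cast convention carries over
to the new helpers: the types and free functions below are hypothetical
stand-ins for illustration, not the clang declarations touched by this patch.

```cpp
#include <cassert>

// Toy model only: hypothetical stand-ins, not clang's Decl/Type classes.
struct EnumDecl {
  int enumerator_count;
};

struct Type {
  const EnumDecl *enum_decl; // non-null only when the type is an enum
};

// Query form: returns null when the type is not an enum, so the caller
// must check the result before using it.
inline const EnumDecl *getAsEnumDecl(const Type &t) { return t.enum_decl; }

// Checked form: the caller already knows the type is an enum, so there is
// no null handling; a mismatch is a programming error caught by the assert.
inline const EnumDecl *castAsEnumDecl(const Type &t) {
  assert(t.enum_decl && "castAsEnumDecl on a non-enum type");
  return t.enum_decl;
}
```

The checked form lets a call site drop a null check that could never fail,
which is the shape of several simplifications in the diff that follows.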

The getAsTagDecl and related functions are not provided inline, because
doing so would introduce circular dependencies. As a result, this patch
causes a small overall performance regression:
<img width="1461" height="18" alt="image" src="https://github.com/user-attachments/assets/061dfb14-9506-4623-91ec-0f02f585d1dd" />

This will be fixed in a later patch, bringing the overall result back to a net
performance improvement:
<img width="1462" height="18" alt="image" src="https://github.com/user-attachments/assets/c237e68f-f696-44f4-acc6-a7c7ba5b0976" />
---
 clang-tools-extra/clang-doc/Serialize.cpp     |   5 +-
 .../EasilySwappableParametersCheck.cpp        |   2 +-
 .../bugprone/TaggedUnionMemberCountCheck.cpp  |  11 +-
 .../cert/DefaultOperatorNewAlignmentCheck.cpp |   2 +-
 .../MissingStdForwardCheck.cpp                |   2 +-
 .../ProTypeMemberInitCheck.cpp                |   2 +-
 .../cppcoreguidelines/SlicingCheck.cpp        |   5 +-
 .../fuchsia/MultipleInheritanceCheck.cpp      |  23 +-
 .../clang-tidy/misc/UnusedUsingDeclsCheck.cpp |   2 +-
 .../modernize/UseScopedLockCheck.cpp          |   2 +-
 .../SuspiciousCallArgumentCheck.cpp           |   7 +-
 .../utils/ExceptionSpecAnalyzer.cpp           |   5 +-
 .../utils/FormatStringConverter.cpp           |   7 +-
 .../clang-tidy/utils/TypeTraits.cpp           |   5 +-
 clang-tools-extra/clangd/Hover.cpp            |   5 +-
 .../clangd/refactor/tweaks/PopulateSwitch.cpp |   6 +-
 clang/include/clang/AST/AbstractBasicReader.h |   2 +-
 clang/include/clang/AST/AbstractBasicWriter.h |   2 +-
 clang/include/clang/AST/Decl.h                |   4 +
 clang/include/clang/AST/DeclCXX.h             |   4 +-
 clang/include/clang/AST/Type.h                |  50 +++-
 clang/include/clang/Basic/TypeNodes.td        |  27 +--
 clang/include/clang/Sema/Sema.h               |   2 +-
 clang/lib/AST/APValue.cpp                     |   3 +-
 clang/lib/AST/ASTContext.cpp                  | 100 ++++----
 clang/lib/AST/ASTStructuralEquivalence.cpp    |   4 +-
 clang/lib/AST/ByteCode/Compiler.cpp           |   7 +-
 clang/lib/AST/ByteCode/Context.cpp            |   3 +-
 clang/lib/AST/ByteCode/InterpBuiltin.cpp      |  21 +-
 clang/lib/AST/ByteCode/Pointer.cpp            |   2 +-
 clang/lib/AST/ByteCode/Program.cpp            |  16 +-
 clang/lib/AST/ByteCode/Record.cpp             |   4 +-
 clang/lib/AST/CXXInheritance.cpp              |  38 +--
 clang/lib/AST/Decl.cpp                        |  30 +--
 clang/lib/AST/DeclCXX.cpp                     |  44 ++--
 clang/lib/AST/DeclTemplate.cpp                |   6 +-
 clang/lib/AST/DeclarationName.cpp             |   4 +-
 clang/lib/AST/Expr.cpp                        |  17 +-
 clang/lib/AST/ExprConstant.cpp                |  58 ++---
 clang/lib/AST/FormatString.cpp                |  10 +-
 clang/lib/AST/InheritViz.cpp                  |   4 +-
 clang/lib/AST/ItaniumCXXABI.cpp               |   3 +-
 clang/lib/AST/ItaniumMangle.cpp               |  17 +-
 clang/lib/AST/JSONNodeDumper.cpp              |   2 +-
 clang/lib/AST/PrintfFormatString.cpp          |   4 +-
 clang/lib/AST/RecordLayoutBuilder.cpp         |  33 ++-
 clang/lib/AST/ScanfFormatString.cpp           |   3 +-
 clang/lib/AST/TemplateBase.cpp                |   4 +-
 clang/lib/AST/TextNodeDumper.cpp              |   2 +-
 clang/lib/AST/Type.cpp                        | 200 ++++++++--------
 clang/lib/AST/TypePrinter.cpp                 |   2 +-
 clang/lib/AST/VTTBuilder.cpp                  |  16 +-
 clang/lib/Analysis/ExprMutationAnalyzer.cpp   |   3 +-
 clang/lib/Analysis/ThreadSafetyCommon.cpp     |  12 +-
 clang/lib/CIR/CodeGen/CIRGenCall.cpp          |  11 +-
 clang/lib/CIR/CodeGen/CIRGenClass.cpp         |  12 +-
 clang/lib/CIR/CodeGen/CIRGenExpr.cpp          |  20 +-
 clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp |   5 +-
 clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp  |   6 +-
 clang/lib/CIR/CodeGen/CIRGenFunction.cpp      |  19 +-
 clang/lib/CIR/CodeGen/CIRGenModule.cpp        |   8 +-
 clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp |   5 +-
 clang/lib/CIR/CodeGen/CIRGenTypes.cpp         |  17 +-
 clang/lib/CIR/CodeGen/TargetInfo.cpp          |   6 +-
 clang/lib/CodeGen/ABIInfo.cpp                 |   3 +-
 clang/lib/CodeGen/ABIInfoImpl.cpp             |  45 ++--
 clang/lib/CodeGen/CGBlocks.cpp                |   9 +-
 clang/lib/CodeGen/CGBuiltin.cpp               |   9 +-
 clang/lib/CodeGen/CGCUDANV.cpp                |   3 +-
 clang/lib/CodeGen/CGCXX.cpp                   |  16 +-
 clang/lib/CodeGen/CGCXXABI.cpp                |   5 +-
 clang/lib/CodeGen/CGCall.cpp                  |  27 +--
 clang/lib/CodeGen/CGClass.cpp                 |  60 +----
 clang/lib/CodeGen/CGDebugInfo.cpp             |   5 +-
 clang/lib/CodeGen/CGDecl.cpp                  |  12 +-
 clang/lib/CodeGen/CGExpr.cpp                  |  32 +--
 clang/lib/CodeGen/CGExprAgg.cpp               |  34 +--
 clang/lib/CodeGen/CGExprCXX.cpp               |  21 +-
 clang/lib/CodeGen/CGExprConstant.cpp          |  36 +--
 clang/lib/CodeGen/CGExprScalar.cpp            |  10 +-
 clang/lib/CodeGen/CGNonTrivialStruct.cpp      |   6 +-
 clang/lib/CodeGen/CGObjC.cpp                  |   6 +-
 clang/lib/CodeGen/CGObjCMac.cpp               |  17 +-
 clang/lib/CodeGen/CGObjCRuntime.cpp           |   6 +-
 clang/lib/CodeGen/CGOpenMPRuntime.cpp         |  29 +--
 clang/lib/CodeGen/CodeGenFunction.cpp         |  11 +-
 clang/lib/CodeGen/CodeGenFunction.h           |   4 +-
 clang/lib/CodeGen/CodeGenModule.cpp           |   6 +-
 clang/lib/CodeGen/CodeGenTBAA.cpp             |   9 +-
 clang/lib/CodeGen/CodeGenTypes.cpp            |  20 +-
 clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp |  28 +--
 clang/lib/CodeGen/ItaniumCXXABI.cpp           |  54 ++---
 clang/lib/CodeGen/SwiftCallingConv.cpp        |   2 +-
 clang/lib/CodeGen/Targets/AArch64.cpp         |  16 +-
 clang/lib/CodeGen/Targets/AMDGPU.cpp          |  19 +-
 clang/lib/CodeGen/Targets/ARC.cpp             |   6 +-
 clang/lib/CodeGen/Targets/ARM.cpp             |  16 +-
 clang/lib/CodeGen/Targets/BPF.cpp             |   9 +-
 clang/lib/CodeGen/Targets/CSKY.cpp            |   4 +-
 clang/lib/CodeGen/Targets/Hexagon.cpp         |   9 +-
 clang/lib/CodeGen/Targets/Lanai.cpp           |   6 +-
 clang/lib/CodeGen/Targets/LoongArch.cpp       |  13 +-
 clang/lib/CodeGen/Targets/Mips.cpp            |  12 +-
 clang/lib/CodeGen/Targets/NVPTX.cpp           |  13 +-
 clang/lib/CodeGen/Targets/PPC.cpp             |  13 +-
 clang/lib/CodeGen/Targets/RISCV.cpp           |  13 +-
 clang/lib/CodeGen/Targets/SPIR.cpp            |  19 +-
 clang/lib/CodeGen/Targets/Sparc.cpp           |   4 +-
 clang/lib/CodeGen/Targets/SystemZ.cpp         |  16 +-
 clang/lib/CodeGen/Targets/WebAssembly.cpp     |   6 +-
 clang/lib/CodeGen/Targets/X86.cpp             |  77 +++---
 clang/lib/CodeGen/Targets/XCore.cpp           |   2 +-
 clang/lib/Edit/RewriteObjCFoundationAPI.cpp   |   2 +-
 .../Frontend/Rewrite/RewriteModernObjC.cpp    |  61 ++---
 clang/lib/Frontend/Rewrite/RewriteObjC.cpp    |   5 +-
 clang/lib/Index/IndexTypeSourceInfo.cpp       |   2 +-
 clang/lib/Index/USRGeneration.cpp             |   3 +-
 .../Interpreter/InterpreterValuePrinter.cpp   |  12 +-
 clang/lib/Interpreter/Value.cpp               |  19 +-
 clang/lib/Sema/Sema.cpp                       |  33 ++-
 clang/lib/Sema/SemaAccess.cpp                 |  11 +-
 clang/lib/Sema/SemaBPF.cpp                    |  14 +-
 clang/lib/Sema/SemaCUDA.cpp                   |  17 +-
 clang/lib/Sema/SemaCXXScopeSpec.cpp           |   7 +-
 clang/lib/Sema/SemaCast.cpp                   |  30 ++-
 clang/lib/Sema/SemaChecking.cpp               |  96 +++-----
 clang/lib/Sema/SemaCodeComplete.cpp           |  13 +-
 clang/lib/Sema/SemaDecl.cpp                   | 165 ++++++-------
 clang/lib/Sema/SemaDeclAttr.cpp               |  56 ++---
 clang/lib/Sema/SemaDeclCXX.cpp                | 103 +++-----
 clang/lib/Sema/SemaDeclObjC.cpp               |  27 +--
 clang/lib/Sema/SemaExceptionSpec.cpp          |   5 +-
 clang/lib/Sema/SemaExpr.cpp                   |  70 +++---
 clang/lib/Sema/SemaExprCXX.cpp                |  40 ++--
 clang/lib/Sema/SemaExprObjC.cpp               |   5 +-
 clang/lib/Sema/SemaHLSL.cpp                   |  64 +++--
 clang/lib/Sema/SemaInit.cpp                   | 226 +++++++-----------
 clang/lib/Sema/SemaLambda.cpp                 |   5 +-
 clang/lib/Sema/SemaLookup.cpp                 |  19 +-
 clang/lib/Sema/SemaObjC.cpp                   |   4 +-
 clang/lib/Sema/SemaObjCProperty.cpp           |   6 +-
 clang/lib/Sema/SemaOpenMP.cpp                 |   6 +-
 clang/lib/Sema/SemaOverload.cpp               |  75 +++---
 clang/lib/Sema/SemaPPC.cpp                    |   5 +-
 clang/lib/Sema/SemaStmt.cpp                   |  15 +-
 clang/lib/Sema/SemaStmtAsm.cpp                |  16 +-
 clang/lib/Sema/SemaSwift.cpp                  |  16 +-
 clang/lib/Sema/SemaTemplate.cpp               |  27 +--
 clang/lib/Sema/SemaTemplateDeduction.cpp      |  18 +-
 clang/lib/Sema/SemaTemplateDeductionGuide.cpp |   6 +-
 .../lib/Sema/SemaTemplateInstantiateDecl.cpp  |  14 +-
 clang/lib/Sema/SemaType.cpp                   |  18 +-
 clang/lib/Sema/SemaTypeTraits.cpp             |  20 +-
 clang/lib/Sema/TreeTransform.h                |  42 ++--
 clang/lib/Sema/UsedDeclVisitor.h              |  10 +-
 .../Checkers/CastSizeChecker.cpp              |   5 +-
 .../Checkers/EnumCastOutOfRangeChecker.cpp    |  11 +-
 .../Checkers/LLVMConventionsChecker.cpp       |  18 +-
 .../Checkers/PaddingChecker.cpp               |   6 +-
 .../Checkers/WebKit/PtrTypesSemantics.cpp     |   2 +-
 clang/lib/StaticAnalyzer/Core/ExprEngine.cpp  |   2 +-
 clang/lib/StaticAnalyzer/Core/RegionStore.cpp |   6 +-
 clang/test/CXX/drs/cwg0xx.cpp                 |   4 +-
 clang/test/CXX/drs/cwg3xx.cpp                 |   2 +-
 .../class_template_param_inheritance.cpp      |   2 +-
 clang/tools/libclang/CIndex.cpp               |   2 +-
 clang/unittests/AST/ASTImporterTest.cpp       |   4 +-
 clang/unittests/AST/RandstructTest.cpp        |   3 +-
 clang/utils/TableGen/ASTTableGen.h            |   2 +-
 .../utils/TableGen/ClangTypeNodesEmitter.cpp  |  24 +-
 170 files changed, 1289 insertions(+), 1973 deletions(-)

diff --git a/clang-tools-extra/clang-doc/Serialize.cpp b/clang-tools-extra/clang-doc/Serialize.cpp
index bcab4f1b8a729..506cc0a3b0e8f 100644
--- a/clang-tools-extra/clang-doc/Serialize.cpp
+++ b/clang-tools-extra/clang-doc/Serialize.cpp
@@ -901,9 +901,8 @@ parseBases(RecordInfo &I, const CXXRecordDecl *D, bool IsFileInRootDir,
   if (!D->isThisDeclarationADefinition())
     return;
   for (const CXXBaseSpecifier &B : D->bases()) {
-    if (const RecordType *Ty = B.getType()->getAs<RecordType>()) {
-      if (const CXXRecordDecl *Base = cast_or_null<CXXRecordDecl>(
-              Ty->getOriginalDecl()->getDefinition())) {
+    if (const auto *Base = B.getType()->getAsCXXRecordDecl()) {
+      if (Base->isCompleteDefinition()) {
         // Initialized without USR and name, this will be set in the following
         // if-else stmt.
         BaseRecordInfo BI(
diff --git a/clang-tools-extra/clang-tidy/bugprone/EasilySwappableParametersCheck.cpp b/clang-tools-extra/clang-tidy/bugprone/EasilySwappableParametersCheck.cpp
index 7f1eeef8ea0fd..3c718f1ddbe95 100644
--- a/clang-tools-extra/clang-tidy/bugprone/EasilySwappableParametersCheck.cpp
+++ b/clang-tools-extra/clang-tidy/bugprone/EasilySwappableParametersCheck.cpp
@@ -997,7 +997,7 @@ approximateStandardConversionSequence(const TheCheck &Check, QualType From,
     WorkType = QualType{ToBuiltin, FastQualifiersToApply};
   }
 
-  const auto *FromEnum = WorkType->getAs<EnumType>();
+  const auto *FromEnum = WorkType->getAsCanonical<EnumType>();
   const auto *ToEnum = To->getAs<EnumType>();
   if (FromEnum && ToNumeric && FromEnum->isUnscopedEnumerationType()) {
     // Unscoped enumerations (or enumerations in C) convert to numerics.
diff --git a/clang-tools-extra/clang-tidy/bugprone/TaggedUnionMemberCountCheck.cpp b/clang-tools-extra/clang-tidy/bugprone/TaggedUnionMemberCountCheck.cpp
index ddbb14e3ac62b..02f4421efdbf4 100644
--- a/clang-tools-extra/clang-tidy/bugprone/TaggedUnionMemberCountCheck.cpp
+++ b/clang-tools-extra/clang-tidy/bugprone/TaggedUnionMemberCountCheck.cpp
@@ -169,15 +169,8 @@ void TaggedUnionMemberCountCheck::check(
   if (!Root || !UnionField || !TagField)
     return;
 
-  const auto *UnionDef =
-      UnionField->getType().getCanonicalType().getTypePtr()->getAsRecordDecl();
-  const auto *EnumDef = llvm::dyn_cast<EnumDecl>(
-      TagField->getType().getCanonicalType().getTypePtr()->getAsTagDecl());
-
-  assert(UnionDef && "UnionDef is missing!");
-  assert(EnumDef && "EnumDef is missing!");
-  if (!UnionDef || !EnumDef)
-    return;
+  const auto *UnionDef = UnionField->getType()->castAsRecordDecl();
+  const auto *EnumDef = TagField->getType()->castAsEnumDecl();
 
   const std::size_t UnionMemberCount = llvm::range_size(UnionDef->fields());
   auto [TagCount, CountingEnumConstantDecl] = getNumberOfEnumValues(EnumDef);
diff --git a/clang-tools-extra/clang-tidy/cert/DefaultOperatorNewAlignmentCheck.cpp b/clang-tools-extra/clang-tidy/cert/DefaultOperatorNewAlignmentCheck.cpp
index 2a0d0ada42b28..2c2248afb69e7 100644
--- a/clang-tools-extra/clang-tidy/cert/DefaultOperatorNewAlignmentCheck.cpp
+++ b/clang-tools-extra/clang-tidy/cert/DefaultOperatorNewAlignmentCheck.cpp
@@ -31,7 +31,7 @@ void DefaultOperatorNewAlignmentCheck::check(
     return;
   const TagDecl *D = T->getAsTagDecl();
   // Alignment can not be obtained for undefined type.
-  if (!D || !D->getDefinition() || !D->isCompleteDefinition())
+  if (!D || !D->isCompleteDefinition())
     return;
 
   ASTContext &Context = D->getASTContext();
diff --git a/clang-tools-extra/clang-tidy/cppcoreguidelines/MissingStdForwardCheck.cpp b/clang-tools-extra/clang-tidy/cppcoreguidelines/MissingStdForwardCheck.cpp
index 82d1cf13440bc..75da6de9b5f13 100644
--- a/clang-tools-extra/clang-tidy/cppcoreguidelines/MissingStdForwardCheck.cpp
+++ b/clang-tools-extra/clang-tidy/cppcoreguidelines/MissingStdForwardCheck.cpp
@@ -45,7 +45,7 @@ AST_MATCHER(ParmVarDecl, isTemplateTypeParameter) {
 
   QualType ParamType =
       Node.getType().getNonPackExpansionType()->getPointeeType();
-  const auto *TemplateType = ParamType->getAs<TemplateTypeParmType>();
+  const auto *TemplateType = ParamType->getAsCanonical<TemplateTypeParmType>();
   if (!TemplateType)
     return false;
 
diff --git a/clang-tools-extra/clang-tidy/cppcoreguidelines/ProTypeMemberInitCheck.cpp b/clang-tools-extra/clang-tidy/cppcoreguidelines/ProTypeMemberInitCheck.cpp
index 40607597297b5..e6e79f0f0342a 100644
--- a/clang-tools-extra/clang-tidy/cppcoreguidelines/ProTypeMemberInitCheck.cpp
+++ b/clang-tools-extra/clang-tidy/cppcoreguidelines/ProTypeMemberInitCheck.cpp
@@ -189,7 +189,7 @@ struct InitializerInsertion {
 
 // Convenience utility to get a RecordDecl from a QualType.
 const RecordDecl *getCanonicalRecordDecl(const QualType &Type) {
-  if (const auto *RT = Type.getCanonicalType()->getAs<RecordType>())
+  if (const auto *RT = Type->getAsCanonical<RecordType>())
     return RT->getOriginalDecl();
   return nullptr;
 }
diff --git a/clang-tools-extra/clang-tidy/cppcoreguidelines/SlicingCheck.cpp b/clang-tools-extra/clang-tidy/cppcoreguidelines/SlicingCheck.cpp
index 40fd15c08f0a1..6508bfd5ca808 100644
--- a/clang-tools-extra/clang-tidy/cppcoreguidelines/SlicingCheck.cpp
+++ b/clang-tools-extra/clang-tidy/cppcoreguidelines/SlicingCheck.cpp
@@ -90,9 +90,8 @@ void SlicingCheck::diagnoseSlicedOverriddenMethods(
   }
   // Recursively process bases.
   for (const auto &Base : DerivedDecl.bases()) {
-    if (const auto *BaseRecordType = Base.getType()->getAs<RecordType>()) {
-      if (const auto *BaseRecord = cast_or_null<CXXRecordDecl>(
-              BaseRecordType->getOriginalDecl()->getDefinition()))
+    if (const auto *BaseRecord = Base.getType()->getAsCXXRecordDecl()) {
+      if (BaseRecord->isCompleteDefinition())
         diagnoseSlicedOverriddenMethods(Call, *BaseRecord, BaseDecl);
     }
   }
diff --git a/clang-tools-extra/clang-tidy/fuchsia/MultipleInheritanceCheck.cpp b/clang-tools-extra/clang-tidy/fuchsia/MultipleInheritanceCheck.cpp
index 0302a5ad4957c..36b6007b58a51 100644
--- a/clang-tools-extra/clang-tidy/fuchsia/MultipleInheritanceCheck.cpp
+++ b/clang-tools-extra/clang-tidy/fuchsia/MultipleInheritanceCheck.cpp
@@ -71,13 +71,10 @@ bool MultipleInheritanceCheck::isInterface(const CXXRecordDecl *Node) {
   for (const auto &I : Node->bases()) {
     if (I.isVirtual())
       continue;
-    const auto *Ty = I.getType()->getAs<RecordType>();
-    if (!Ty)
+    const auto *Base = I.getType()->getAsCXXRecordDecl();
+    if (!Base)
       continue;
-    const RecordDecl *D = Ty->getOriginalDecl()->getDefinition();
-    if (!D)
-      continue;
-    const auto *Base = cast<CXXRecordDecl>(D);
+    assert(Base->isCompleteDefinition());
     if (!isInterface(Base)) {
       addNodeToInterfaceMap(Node, false);
       return false;
@@ -103,11 +100,10 @@ void MultipleInheritanceCheck::check(const MatchFinder::MatchResult &Result) {
     for (const auto &I : D->bases()) {
       if (I.isVirtual())
         continue;
-      const auto *Ty = I.getType()->getAs<RecordType>();
-      if (!Ty)
+      const auto *Base = I.getType()->getAsCXXRecordDecl();
+      if (!Base)
         continue;
-      const auto *Base =
-          cast<CXXRecordDecl>(Ty->getOriginalDecl()->getDefinition());
+      assert(Base->isCompleteDefinition());
       if (!isInterface(Base))
         NumConcrete++;
     }
@@ -115,11 +111,10 @@ void MultipleInheritanceCheck::check(const MatchFinder::MatchResult &Result) {
     // Check virtual bases to see if there is more than one concrete
     // non-virtual base.
     for (const auto &V : D->vbases()) {
-      const auto *Ty = V.getType()->getAs<RecordType>();
-      if (!Ty)
+      const auto *Base = V.getType()->getAsCXXRecordDecl();
+      if (!Base)
         continue;
-      const auto *Base =
-          cast<CXXRecordDecl>(Ty->getOriginalDecl()->getDefinition());
+      assert(Base->isCompleteDefinition());
       if (!isInterface(Base))
         NumConcrete++;
     }
diff --git a/clang-tools-extra/clang-tidy/misc/UnusedUsingDeclsCheck.cpp b/clang-tools-extra/clang-tidy/misc/UnusedUsingDeclsCheck.cpp
index 8211a0ec6a5e1..49432073ce1d7 100644
--- a/clang-tools-extra/clang-tidy/misc/UnusedUsingDeclsCheck.cpp
+++ b/clang-tools-extra/clang-tidy/misc/UnusedUsingDeclsCheck.cpp
@@ -131,7 +131,7 @@ void UnusedUsingDeclsCheck::check(const MatchFinder::MatchResult &Result) {
       return;
     }
     if (const auto *ECD = dyn_cast<EnumConstantDecl>(Used)) {
-      if (const auto *ET = ECD->getType()->getAs<EnumType>())
+      if (const auto *ET = ECD->getType()->getAsCanonical<EnumType>())
         removeFromFoundDecls(ET->getOriginalDecl());
     }
   };
diff --git a/clang-tools-extra/clang-tidy/modernize/UseScopedLockCheck.cpp b/clang-tools-extra/clang-tidy/modernize/UseScopedLockCheck.cpp
index 5310f2fd25381..c74db0ed861b4 100644
--- a/clang-tools-extra/clang-tidy/modernize/UseScopedLockCheck.cpp
+++ b/clang-tools-extra/clang-tidy/modernize/UseScopedLockCheck.cpp
@@ -28,7 +28,7 @@ static bool isLockGuardDecl(const NamedDecl *Decl) {
 }
 
 static bool isLockGuard(const QualType &Type) {
-  if (const auto *Record = Type->getAs<RecordType>())
+  if (const auto *Record = Type->getAsCanonical<RecordType>())
     if (const RecordDecl *Decl = Record->getOriginalDecl())
       return isLockGuardDecl(Decl);
 
diff --git a/clang-tools-extra/clang-tidy/readability/SuspiciousCallArgumentCheck.cpp b/clang-tools-extra/clang-tidy/readability/SuspiciousCallArgumentCheck.cpp
index 447c2437666cf..a80637dee18f4 100644
--- a/clang-tools-extra/clang-tidy/readability/SuspiciousCallArgumentCheck.cpp
+++ b/clang-tools-extra/clang-tidy/readability/SuspiciousCallArgumentCheck.cpp
@@ -413,10 +413,11 @@ static bool areTypesCompatible(QualType ArgType, QualType ParamType,
 
   // Arithmetic types are interconvertible, except scoped enums.
   if (ParamType->isArithmeticType() && ArgType->isArithmeticType()) {
-    if ((ParamType->isEnumeralType() &&
-         ParamType->castAs<EnumType>()->getOriginalDecl()->isScoped()) ||
+    if ((ParamType->isEnumeralType() && ParamType->castAsCanonical<EnumType>()
+                                            ->getOriginalDecl()
+                                            ->isScoped()) ||
         (ArgType->isEnumeralType() &&
-         ArgType->castAs<EnumType>()->getOriginalDecl()->isScoped()))
+         ArgType->castAsCanonical<EnumType>()->getOriginalDecl()->isScoped()))
       return false;
 
     return true;
diff --git a/clang-tools-extra/clang-tidy/utils/ExceptionSpecAnalyzer.cpp b/clang-tools-extra/clang-tidy/utils/ExceptionSpecAnalyzer.cpp
index aa6aefcf0c493..4314817e4f69d 100644
--- a/clang-tools-extra/clang-tidy/utils/ExceptionSpecAnalyzer.cpp
+++ b/clang-tools-extra/clang-tidy/utils/ExceptionSpecAnalyzer.cpp
@@ -66,10 +66,7 @@ ExceptionSpecAnalyzer::analyzeBase(const CXXBaseSpecifier &Base,
   if (!RecType)
     return State::Unknown;
 
-  const auto *BaseClass =
-      cast<CXXRecordDecl>(RecType->getOriginalDecl())->getDefinitionOrSelf();
-
-  return analyzeRecord(BaseClass, Kind);
+  return analyzeRecord(RecType->getAsCXXRecordDecl(), Kind);
 }
 
 ExceptionSpecAnalyzer::State
diff --git a/clang-tools-extra/clang-tidy/utils/FormatStringConverter.cpp b/clang-tools-extra/clang-tidy/utils/FormatStringConverter.cpp
index 0df8e913100fc..0d0834dc38fc6 100644
--- a/clang-tools-extra/clang-tidy/utils/FormatStringConverter.cpp
+++ b/clang-tools-extra/clang-tidy/utils/FormatStringConverter.cpp
@@ -460,10 +460,9 @@ bool FormatStringConverter::emitIntegerArgument(
     // be passed as its underlying type. However, printf will have forced
     // the signedness based on the format string, so we need to do the
     // same.
-    if (const auto *ET = ArgType->getAs<EnumType>()) {
-      if (const std::optional<std::string> MaybeCastType = castTypeForArgument(
-              ArgKind,
-              ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType()))
+    if (const auto *ED = ArgType->getAsEnumDecl()) {
+      if (const std::optional<std::string> MaybeCastType =
+              castTypeForArgument(ArgKind, ED->getIntegerType()))
         ArgFixes.emplace_back(
             ArgIndex, (Twine("static_cast<") + *MaybeCastType + ">(").str());
       else
diff --git a/clang-tools-extra/clang-tidy/utils/TypeTraits.cpp b/clang-tools-extra/clang-tidy/utils/TypeTraits.cpp
index 5518afd32e7f0..f944306171135 100644
--- a/clang-tools-extra/clang-tidy/utils/TypeTraits.cpp
+++ b/clang-tools-extra/clang-tidy/utils/TypeTraits.cpp
@@ -119,9 +119,8 @@ bool isTriviallyDefaultConstructible(QualType Type, const ASTContext &Context) {
   if (CanonicalType->isScalarType() || CanonicalType->isVectorType())
     return true;
 
-  if (const auto *RT = CanonicalType->getAs<RecordType>()) {
-    return recordIsTriviallyDefaultConstructible(
-        *RT->getOriginalDecl()->getDefinitionOrSelf(), Context);
+  if (const auto *RD = CanonicalType->getAsRecordDecl()) {
+    return recordIsTriviallyDefaultConstructible(*RD, Context);
   }
 
   // No other types can match.
diff --git a/clang-tools-extra/clangd/Hover.cpp b/clang-tools-extra/clangd/Hover.cpp
index a7cf45c632827..6e45c9ae4f8bf 100644
--- a/clang-tools-extra/clangd/Hover.cpp
+++ b/clang-tools-extra/clangd/Hover.cpp
@@ -454,8 +454,7 @@ std::optional<std::string> printExprValue(const Expr *E,
       Constant.Val.getInt().getSignificantBits() <= 64) {
     // Compare to int64_t to avoid bit-width match requirements.
     int64_t Val = Constant.Val.getInt().getExtValue();
-    for (const EnumConstantDecl *ECD :
-         T->castAs<EnumType>()->getOriginalDecl()->enumerators())
+    for (const EnumConstantDecl *ECD : T->castAsEnumDecl()->enumerators())
       if (ECD->getInitVal() == Val)
         return llvm::formatv("{0} ({1})", ECD->getNameAsString(),
                              printHex(Constant.Val.getInt()))
@@ -832,7 +831,7 @@ std::optional<HoverInfo> getThisExprHoverContents(const CXXThisExpr *CTE,
                                                   ASTContext &ASTCtx,
                                                   const PrintingPolicy &PP) {
   QualType OriginThisType = CTE->getType()->getPointeeType();
-  QualType ClassType = declaredType(OriginThisType->getAsTagDecl());
+  QualType ClassType = declaredType(OriginThisType->castAsTagDecl());
   // For partial specialization class, origin `this` pointee type will be
   // parsed as `InjectedClassNameType`, which will ouput template arguments
   // like "type-parameter-0-0". So we retrieve user written class type in this
diff --git a/clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp b/clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp
index 7e616968c6046..2c9841762b869 100644
--- a/clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp
+++ b/clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp
@@ -113,11 +113,11 @@ bool PopulateSwitch::prepare(const Selection &Sel) {
   // Ignore implicit casts, since enums implicitly cast to integer types.
   Cond = Cond->IgnoreParenImpCasts();
   // Get the canonical type to handle typedefs.
-  EnumT = Cond->getType().getCanonicalType()->getAsAdjusted<EnumType>();
+  EnumT = Cond->getType()->getAsCanonical<EnumType>();
   if (!EnumT)
     return false;
-  EnumD = EnumT->getOriginalDecl();
-  if (!EnumD || EnumD->isDependentType())
+  EnumD = EnumT->getOriginalDecl()->getDefinitionOrSelf();
+  if (EnumD->isDependentType())
     return false;
 
   // Finally, check which cases exist and which are covered.
diff --git a/clang/include/clang/AST/AbstractBasicReader.h b/clang/include/clang/AST/AbstractBasicReader.h
index 26052b8086cf7..0d187eb49d6ca 100644
--- a/clang/include/clang/AST/AbstractBasicReader.h
+++ b/clang/include/clang/AST/AbstractBasicReader.h
@@ -193,7 +193,7 @@ class DataStreamBasicReader : public BasicReaderBase<Impl> {
     auto elemTy = origTy;
     unsigned pathLength = asImpl().readUInt32();
     for (unsigned i = 0; i < pathLength; ++i) {
-      if (elemTy->template getAs<RecordType>()) {
+      if (elemTy->isRecordType()) {
         unsigned int_ = asImpl().readUInt32();
         Decl *decl = asImpl().template readDeclAs<Decl>();
         if (auto *recordDecl = dyn_cast<CXXRecordDecl>(decl))
diff --git a/clang/include/clang/AST/AbstractBasicWriter.h b/clang/include/clang/AST/AbstractBasicWriter.h
index d41e655986ef9..8ea0c29cf5070 100644
--- a/clang/include/clang/AST/AbstractBasicWriter.h
+++ b/clang/include/clang/AST/AbstractBasicWriter.h
@@ -176,7 +176,7 @@ class DataStreamBasicWriter : public BasicWriterBase<Impl> {
     asImpl().writeUInt32(path.size());
     auto &ctx = ((BasicWriterBase<Impl> *)this)->getASTContext();
     for (auto elem : path) {
-      if (elemTy->getAs<RecordType>()) {
+      if (elemTy->isRecordType()) {
         asImpl().writeUInt32(elem.getAsBaseOrMember().getInt());
         const Decl *baseOrMember = elem.getAsBaseOrMember().getPointer();
         if (const auto *recordDecl = dyn_cast<CXXRecordDecl>(baseOrMember)) {
diff --git a/clang/include/clang/AST/Decl.h b/clang/include/clang/AST/Decl.h
index bebbde3661a33..79636a67dafba 100644
--- a/clang/include/clang/AST/Decl.h
+++ b/clang/include/clang/AST/Decl.h
@@ -3915,6 +3915,10 @@ class TagDecl : public TypeDecl,
   bool isUnion() const { return getTagKind() == TagTypeKind::Union; }
   bool isEnum() const { return getTagKind() == TagTypeKind::Enum; }
 
+  bool isStructureOrClass() const {
+    return isStruct() || isClass() || isInterface();
+  }
+
   /// Is this tag type named, either directly or via being defined in
   /// a typedef of this type?
   ///
diff --git a/clang/include/clang/AST/DeclCXX.h b/clang/include/clang/AST/DeclCXX.h
index 1d2ef0f4f2319..3ee03122f50cf 100644
--- a/clang/include/clang/AST/DeclCXX.h
+++ b/clang/include/clang/AST/DeclCXX.h
@@ -3825,7 +3825,9 @@ class UsingEnumDecl : public BaseUsingDecl, public Mergeable<UsingEnumDecl> {
   void setEnumType(TypeSourceInfo *TSI) { EnumType = TSI; }
 
 public:
-  EnumDecl *getEnumDecl() const { return cast<EnumDecl>(EnumType->getType()->getAsTagDecl()); }
+  EnumDecl *getEnumDecl() const {
+    return cast<clang::EnumType>(EnumType->getType())->getOriginalDecl();
+  }
 
   static UsingEnumDecl *Create(ASTContext &C, DeclContext *DC,
                                SourceLocation UsingL, SourceLocation EnumL,
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index adf5cb0462154..e4e2463d7f55a 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2883,14 +2883,23 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   /// because the type is a RecordType or because it is the injected-class-name
   /// type of a class template or class template partial specialization.
   CXXRecordDecl *getAsCXXRecordDecl() const;
+  CXXRecordDecl *castAsCXXRecordDecl() const;
 
   /// Retrieves the RecordDecl this type refers to.
   RecordDecl *getAsRecordDecl() const;
+  RecordDecl *castAsRecordDecl() const;
+
+  /// Retrieves the TagDecl that this type refers to, either
+  /// because the type is a TagType or because it is the injected-class-name
+  /// type of a class template or class template partial specialization.
+  EnumDecl *getAsEnumDecl() const;
+  EnumDecl *castAsEnumDecl() const;
 
   /// Retrieves the TagDecl that this type refers to, either
   /// because the type is a TagType or because it is the injected-class-name
   /// type of a class template or class template partial specialization.
   TagDecl *getAsTagDecl() const;
+  TagDecl *castAsTagDecl() const;
 
   /// If this is a pointer or reference to a RecordType, return the
   /// CXXRecordDecl that the type refers to.
@@ -2922,8 +2931,31 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   ///
   /// There are some specializations of this member template listed
   /// immediately following this class.
+  ///
+  /// If you are interested only in the canonical properties of this type,
+  /// consider using getAsCanonical instead, as that is much faster.
   template <typename T> const T *getAs() const;
 
+  /// If this type is canonically the specified type, return its canonical type
+  /// cast to that specified type; otherwise, return null.
+  template <typename T> const T *getAsCanonical() const {
+    return dyn_cast<T>(CanonicalType);
+  }
+
+  /// Return this type's canonical type cast to the specified type.
+  /// If the type is not canonically that specified type, the behaviour is
+  /// undefined.
+  template <typename T> const T *castAsCanonical() const {
+    return cast<T>(CanonicalType);
+  }
+
+// Deleted for never-canonical types, on which these casts can never succeed.
+#define TYPE(Class, Base)
+#define NEVER_CANONICAL_TYPE(Class)                                            \
+  template <> inline const Class##Type *Type::getAsCanonical() const = delete; \
+  template <> inline const Class##Type *Type::castAsCanonical() const = delete;
+#include "clang/AST/TypeNodes.inc"
+
   /// Look through sugar for an instance of TemplateSpecializationType which
   /// is not a type alias, or null if there is no such type.
   /// This is used when you want as-written template arguments or the template
@@ -3135,16 +3167,16 @@ template <> const BoundsAttributedType *Type::getAs() const;
 /// sugar until it reaches an CountAttributedType or a non-sugared type.
 template <> const CountAttributedType *Type::getAs() const;
 
-// We can do canonical leaf types faster, because we don't have to
-// worry about preserving child type decoration.
+// We can handle always-canonical types faster, because we don't have to
+// worry about preserving type sugar.
 #define TYPE(Class, Base)
-#define LEAF_TYPE(Class) \
-template <> inline const Class##Type *Type::getAs() const { \
-  return dyn_cast<Class##Type>(CanonicalType); \
-} \
-template <> inline const Class##Type *Type::castAs() const { \
-  return cast<Class##Type>(CanonicalType); \
-}
+#define ALWAYS_CANONICAL_TYPE(Class)                                           \
+  template <> inline const Class##Type *Type::getAs() const {                  \
+    return dyn_cast<Class##Type>(CanonicalType);                               \
+  }                                                                            \
+  template <> inline const Class##Type *Type::castAs() const {                 \
+    return cast<Class##Type>(CanonicalType);                                   \
+  }
 #include "clang/AST/TypeNodes.inc"
 
 /// This class is used for builtin types like 'int'.  Builtin
diff --git a/clang/include/clang/Basic/TypeNodes.td b/clang/include/clang/Basic/TypeNodes.td
index c8d45dec78816..fb6862b90987f 100644
--- a/clang/include/clang/Basic/TypeNodes.td
+++ b/clang/include/clang/Basic/TypeNodes.td
@@ -37,21 +37,12 @@ class NeverCanonical {}
 /// canonical types can ignore these nodes.
 class NeverCanonicalUnlessDependent {}
 
-/// A type node which never has component type structure.  Some code may be
-/// able to operate on leaf types faster than they can on non-leaf types.
-///
-/// For example, the function type `void (int)` is not a leaf type because it
-/// is structurally composed of component types (`void` and `int`).
-///
-/// A struct type is a leaf type because its field types are not part of its
-/// type-expression.
-///
-/// Nodes like `TypedefType` which are syntactically leaves but can desugar
-/// to types that may not be leaves should not declare this.
-class LeafType {}
+/// A type node which is always canonical, that is, a type `T` for which
+/// `T.getCanonicalType() == T` always holds.
+class AlwaysCanonical {}
 
 def Type : TypeNode<?, 1>;
-def BuiltinType : TypeNode<Type>, LeafType;
+def BuiltinType : TypeNode<Type>, AlwaysCanonical;
 def ComplexType : TypeNode<Type>;
 def PointerType : TypeNode<Type>;
 def BlockPointerType : TypeNode<Type>;
@@ -88,14 +79,14 @@ def TypeOfType : TypeNode<Type>, NeverCanonicalUnlessDependent;
 def DecltypeType : TypeNode<Type>, NeverCanonicalUnlessDependent;
 def UnaryTransformType : TypeNode<Type>, NeverCanonicalUnlessDependent;
 def TagType : TypeNode<Type, 1>;
-def RecordType : TypeNode<TagType>, LeafType;
-def EnumType : TypeNode<TagType>, LeafType;
-def InjectedClassNameType : TypeNode<TagType>, AlwaysDependent, LeafType;
+def RecordType : TypeNode<TagType>;
+def EnumType : TypeNode<TagType>;
+def InjectedClassNameType : TypeNode<TagType>, AlwaysDependent;
 def AttributedType : TypeNode<Type>, NeverCanonical;
 def BTFTagAttributedType : TypeNode<Type>, NeverCanonical;
 def HLSLAttributedResourceType : TypeNode<Type>;
 def HLSLInlineSpirvType : TypeNode<Type>;
-def TemplateTypeParmType : TypeNode<Type>, AlwaysDependent, LeafType;
+def TemplateTypeParmType : TypeNode<Type>, AlwaysDependent;
 def SubstTemplateTypeParmType : TypeNode<Type>, NeverCanonical;
 def SubstPackType : TypeNode<Type, 1>;
 def SubstTemplateTypeParmPackType : TypeNode<SubstPackType>, AlwaysDependent;
@@ -110,7 +101,7 @@ def PackExpansionType : TypeNode<Type>, AlwaysDependent;
 def PackIndexingType  : TypeNode<Type>, NeverCanonicalUnlessDependent;
 def ObjCTypeParamType : TypeNode<Type>, NeverCanonical;
 def ObjCObjectType : TypeNode<Type>;
-def ObjCInterfaceType : TypeNode<ObjCObjectType>, LeafType;
+def ObjCInterfaceType : TypeNode<ObjCObjectType>, AlwaysCanonical;
 def ObjCObjectPointerType : TypeNode<Type>;
 def BoundsAttributedType : TypeNode<Type, 1>;
 def CountAttributedType : TypeNode<BoundsAttributedType>, NeverCanonical;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 12112eb1ad225..2c9333491e930 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -5404,7 +5404,7 @@ class Sema final : public SemaBase {
 
   /// FinalizeVarWithDestructor - Prepare for calling destructor on the
   /// constructed variable.
-  void FinalizeVarWithDestructor(VarDecl *VD, const RecordType *DeclInitType);
+  void FinalizeVarWithDestructor(VarDecl *VD, CXXRecordDecl *DeclInit);
 
   /// Helper class that collects exception specifications for
   /// implicitly-declared special member functions.
diff --git a/clang/lib/AST/APValue.cpp b/clang/lib/AST/APValue.cpp
index 2d62209bbc28c..7173c2a0e1a2a 100644
--- a/clang/lib/AST/APValue.cpp
+++ b/clang/lib/AST/APValue.cpp
@@ -903,8 +903,7 @@ void APValue::printPretty(raw_ostream &Out, const PrintingPolicy &Policy,
   case APValue::Struct: {
     Out << '{';
     bool First = true;
-    const RecordDecl *RD =
-        Ty->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *RD = Ty->castAsRecordDecl();
     if (unsigned N = getStructNumBases()) {
       const CXXRecordDecl *CD = cast<CXXRecordDecl>(RD);
       CXXRecordDecl::base_class_const_iterator BI = CD->bases_begin();
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index bbb957067c4c8..405ee15d65cb7 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -654,9 +654,9 @@ comments::FullComment *ASTContext::getCommentForDecl(
       // does not have one of its own.
       QualType QT = TD->getUnderlyingType();
       if (const auto *TT = QT->getAs<TagType>())
-        if (const Decl *TD = TT->getOriginalDecl())
-          if (comments::FullComment *FC = getCommentForDecl(TD, PP))
-            return cloneFullComment(FC, D);
+        if (comments::FullComment *FC =
+                getCommentForDecl(TT->getOriginalDecl(), PP))
+          return cloneFullComment(FC, D);
     }
     else if (const auto *IC = dyn_cast<ObjCInterfaceDecl>(D)) {
       while (IC->getSuperClass()) {
@@ -1933,12 +1933,9 @@ TypeInfoChars ASTContext::getTypeInfoDataSizeInChars(QualType T) const {
   // of a base-class subobject.  We decide whether that's possible
   // during class layout, so here we can just trust the layout results.
   if (getLangOpts().CPlusPlus) {
-    if (const auto *RT = T->getAs<RecordType>()) {
-      const auto *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
-      if (!RD->isInvalidDecl()) {
-        const ASTRecordLayout &layout = getASTRecordLayout(RD);
-        Info.Width = layout.getDataSize();
-      }
+    if (const auto *RD = T->getAsCXXRecordDecl(); RD && !RD->isInvalidDecl()) {
+      const ASTRecordLayout &layout = getASTRecordLayout(RD);
+      Info.Width = layout.getDataSize();
     }
   }
 
@@ -2004,8 +2001,7 @@ bool ASTContext::isPromotableIntegerType(QualType T) const {
 
   // Enumerated types are promotable to their compatible integer types
   // (C99 6.3.1.1) a.k.a. its underlying type (C++ [conv.prom]p2).
-  if (const auto *ET = T->getAs<EnumType>()) {
-    const EnumDecl *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *ED = T->getAsEnumDecl()) {
     if (T->isDependentType() || ED->getPromotionType().isNull() ||
         ED->isScoped())
       return false;
@@ -2618,10 +2614,10 @@ unsigned ASTContext::getTypeUnadjustedAlign(const Type *T) const {
     return I->second;
 
   unsigned UnadjustedAlign;
-  if (const auto *RT = T->getAs<RecordType>()) {
+  if (const auto *RT = T->getAsCanonical<RecordType>()) {
     const ASTRecordLayout &Layout = getASTRecordLayout(RT->getOriginalDecl());
     UnadjustedAlign = toBits(Layout.getUnadjustedAlignment());
-  } else if (const auto *ObjCI = T->getAs<ObjCInterfaceType>()) {
+  } else if (const auto *ObjCI = T->getAsCanonical<ObjCInterfaceType>()) {
     const ASTRecordLayout &Layout = getASTObjCInterfaceLayout(ObjCI->getDecl());
     UnadjustedAlign = toBits(Layout.getUnadjustedAlignment());
   } else {
@@ -2694,9 +2690,7 @@ unsigned ASTContext::getPreferredTypeAlign(const Type *T) const {
   if (!Target->allowsLargerPreferedTypeAlignment())
     return ABIAlign;
 
-  if (const auto *RT = T->getAs<RecordType>()) {
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
-
+  if (const auto *RD = T->getAsRecordDecl()) {
     // When used as part of a typedef, or together with a 'packed' attribute,
     // the 'aligned' attribute can be used to decrease alignment. Note that the
     // 'packed' case is already taken into consideration when computing the
@@ -2717,11 +2711,8 @@ unsigned ASTContext::getPreferredTypeAlign(const Type *T) const {
   // possible.
   if (const auto *CT = T->getAs<ComplexType>())
     T = CT->getElementType().getTypePtr();
-  if (const auto *ET = T->getAs<EnumType>())
-    T = ET->getOriginalDecl()
-            ->getDefinitionOrSelf()
-            ->getIntegerType()
-            .getTypePtr();
+  if (const auto *ED = T->getAsEnumDecl())
+    T = ED->getIntegerType().getTypePtr();
   if (T->isSpecificBuiltinType(BuiltinType::Double) ||
       T->isSpecificBuiltinType(BuiltinType::LongLong) ||
       T->isSpecificBuiltinType(BuiltinType::ULongLong) ||
@@ -2887,12 +2878,10 @@ structHasUniqueObjectRepresentations(const ASTContext &Context,
 static std::optional<int64_t>
 getSubobjectSizeInBits(const FieldDecl *Field, const ASTContext &Context,
                        bool CheckIfTriviallyCopyable) {
-  if (Field->getType()->isRecordType()) {
-    const RecordDecl *RD = Field->getType()->getAsRecordDecl();
-    if (!RD->isUnion())
-      return structHasUniqueObjectRepresentations(Context, RD,
-                                                  CheckIfTriviallyCopyable);
-  }
+  if (const auto *RD = Field->getType()->getAsRecordDecl();
+      RD && !RD->isUnion())
+    return structHasUniqueObjectRepresentations(Context, RD,
+                                                CheckIfTriviallyCopyable);
 
   // A _BitInt type may not be unique if it has padding bits
   // but if it is a bitfield the padding bits are not used.
@@ -3047,10 +3036,7 @@ bool ASTContext::hasUniqueObjectRepresentations(
   if (const auto *MPT = Ty->getAs<MemberPointerType>())
     return !ABI->getMemberPointerInfo(MPT).HasPadding;
 
-  if (Ty->isRecordType()) {
-    const RecordDecl *Record =
-        Ty->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
-
+  if (const auto *Record = Ty->getAsRecordDecl()) {
     if (Record->isInvalidDecl())
       return false;
 
@@ -3422,10 +3408,7 @@ static void encodeTypeForFunctionPointerAuth(const ASTContext &Ctx,
     //   type, or an unsigned integer type.
     //
     // So we have to treat enum types as integers.
-    QualType UnderlyingType = cast<EnumType>(T)
-                                  ->getOriginalDecl()
-                                  ->getDefinitionOrSelf()
-                                  ->getIntegerType();
+    QualType UnderlyingType = T->castAsEnumDecl()->getIntegerType();
     return encodeTypeForFunctionPointerAuth(
         Ctx, OS, UnderlyingType.isNull() ? Ctx.IntTy : UnderlyingType);
   }
@@ -3569,8 +3552,7 @@ static void encodeTypeForFunctionPointerAuth(const ASTContext &Ctx,
     llvm_unreachable("should never get here");
   }
   case Type::Record: {
-    const RecordDecl *RD =
-        T->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    const RecordDecl *RD = T->castAsCanonical<RecordType>()->getOriginalDecl();
     const IdentifierInfo *II = RD->getIdentifier();
 
     // In C++, an immediate typedef of an anonymous struct or union
@@ -8362,8 +8344,8 @@ QualType ASTContext::isPromotableBitField(Expr *E) const {
 QualType ASTContext::getPromotedIntegerType(QualType Promotable) const {
   assert(!Promotable.isNull());
   assert(isPromotableIntegerType(Promotable));
-  if (const auto *ET = Promotable->getAs<EnumType>())
-    return ET->getOriginalDecl()->getDefinitionOrSelf()->getPromotionType();
+  if (const auto *ED = Promotable->getAsEnumDecl())
+    return ED->getPromotionType();
 
   if (const auto *BT = Promotable->getAs<BuiltinType>()) {
     // C++ [conv.prom]: A prvalue of type char16_t, char32_t, or wchar_t
@@ -8582,10 +8564,9 @@ QualType ASTContext::getObjCSuperType() const {
 }
 
 void ASTContext::setCFConstantStringType(QualType T) {
-  const auto *TD = T->castAs<TypedefType>();
-  CFConstantStringTypeDecl = cast<TypedefDecl>(TD->getDecl());
-  const auto *TagType = TD->castAs<RecordType>();
-  CFConstantStringTagDecl = TagType->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *TT = T->castAs<TypedefType>();
+  CFConstantStringTypeDecl = cast<TypedefDecl>(TT->getDecl());
+  CFConstantStringTagDecl = TT->castAsRecordDecl();
 }
 
 QualType ASTContext::getBlockDescriptorType() const {
@@ -9296,8 +9277,8 @@ static char getObjCEncodingForPrimitiveType(const ASTContext *C,
     llvm_unreachable("invalid BuiltinType::Kind value");
 }
 
-static char ObjCEncodingForEnumType(const ASTContext *C, const EnumType *ET) {
-  EnumDecl *Enum = ET->getOriginalDecl()->getDefinitionOrSelf();
+static char ObjCEncodingForEnumDecl(const ASTContext *C, const EnumDecl *ED) {
+  EnumDecl *Enum = ED->getDefinitionOrSelf();
 
   // The encoding of an non-fixed enum type is always 'i', regardless of size.
   if (!Enum->isFixed())
@@ -9340,8 +9321,8 @@ static void EncodeBitField(const ASTContext *Ctx, std::string& S,
 
     S += llvm::utostr(Offset);
 
-    if (const auto *ET = T->getAs<EnumType>())
-      S += ObjCEncodingForEnumType(Ctx, ET);
+    if (const auto *ET = T->getAsCanonical<EnumType>())
+      S += ObjCEncodingForEnumDecl(Ctx, ET->getOriginalDecl());
     else {
       const auto *BT = T->castAs<BuiltinType>();
       S += getObjCEncodingForPrimitiveType(Ctx, BT);
@@ -9398,7 +9379,7 @@ void ASTContext::getObjCEncodingForTypeImpl(QualType T, std::string &S,
     if (const auto *BT = dyn_cast<BuiltinType>(CT))
       S += getObjCEncodingForPrimitiveType(this, BT);
     else
-      S += ObjCEncodingForEnumType(this, cast<EnumType>(CT));
+      S += ObjCEncodingForEnumDecl(this, cast<EnumType>(CT)->getOriginalDecl());
     return;
 
   case Type::Complex:
@@ -9465,7 +9446,7 @@ void ASTContext::getObjCEncodingForTypeImpl(QualType T, std::string &S,
         S += '*';
         return;
       }
-    } else if (const auto *RTy = PointeeTy->getAs<RecordType>()) {
+    } else if (const auto *RTy = PointeeTy->getAsCanonical<RecordType>()) {
       const IdentifierInfo *II = RTy->getOriginalDecl()->getIdentifier();
       // GCC binary compat: Need to convert "struct objc_class *" to "#".
       if (II == &Idents.get("objc_class")) {
@@ -11672,9 +11653,8 @@ QualType ASTContext::mergeFunctionTypes(QualType lhs, QualType rhs,
 
       // Look at the converted type of enum types, since that is the type used
       // to pass enum values.
-      if (const auto *Enum = paramTy->getAs<EnumType>()) {
-        paramTy =
-            Enum->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+      if (const auto *ED = paramTy->getAsEnumDecl()) {
+        paramTy = ED->getIntegerType();
         if (paramTy.isNull())
           return {};
       }
@@ -11832,10 +11812,10 @@ QualType ASTContext::mergeTypes(QualType LHS, QualType RHS, bool OfBlockPointer,
   if (LHSClass != RHSClass) {
     // Note that we only have special rules for turning block enum
     // returns into block int returns, not vice-versa.
-    if (const auto *ETy = LHS->getAs<EnumType>()) {
+    if (const auto *ETy = LHS->getAsCanonical<EnumType>()) {
       return mergeEnumWithInteger(*this, ETy, RHS, false);
     }
-    if (const EnumType* ETy = RHS->getAs<EnumType>()) {
+    if (const EnumType *ETy = RHS->getAsCanonical<EnumType>()) {
       return mergeEnumWithInteger(*this, ETy, LHS, BlockReturnType);
     }
     // allow block pointer type to match an 'id' type.
@@ -12265,8 +12245,8 @@ QualType ASTContext::mergeObjCGCQualifiers(QualType LHS, QualType RHS) {
 //===----------------------------------------------------------------------===//
 
 unsigned ASTContext::getIntWidth(QualType T) const {
-  if (const auto *ET = T->getAs<EnumType>())
-    T = ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = T->getAsEnumDecl())
+    T = ED->getIntegerType();
   if (T->isBooleanType())
     return 1;
   if (const auto *EIT = T->getAs<BitIntType>())
@@ -12291,8 +12271,8 @@ QualType ASTContext::getCorrespondingUnsignedType(QualType T) const {
 
   // For enums, get the underlying integer type of the enum, and let the general
   // integer type signchanging code handle it.
-  if (const auto *ETy = T->getAs<EnumType>())
-    T = ETy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = T->getAsEnumDecl())
+    T = ED->getIntegerType();
 
   switch (T->castAs<BuiltinType>()->getKind()) {
   case BuiltinType::Char_U:
@@ -12365,8 +12345,8 @@ QualType ASTContext::getCorrespondingSignedType(QualType T) const {
 
   // For enums, get the underlying integer type of the enum, and let the general
   // integer type signchanging code handle it.
-  if (const auto *ETy = T->getAs<EnumType>())
-    T = ETy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = T->getAsEnumDecl())
+    T = ED->getIntegerType();
 
   switch (T->castAs<BuiltinType>()->getKind()) {
   case BuiltinType::Char_S:
diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp
index 6d83de384ee10..d118f97b49942 100644
--- a/clang/lib/AST/ASTStructuralEquivalence.cpp
+++ b/clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -878,10 +878,10 @@ static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context,
       // Treat the enumeration as its underlying type and use the builtin type
       // class comparison.
       if (T1->getTypeClass() == Type::Enum) {
-        T1 = T1->getAs<EnumType>()->getOriginalDecl()->getIntegerType();
+        T1 = cast<EnumType>(T1)->getOriginalDecl()->getIntegerType();
         assert(T2->isBuiltinType() && !T1.isNull()); // Sanity check
       } else if (T2->getTypeClass() == Type::Enum) {
-        T2 = T2->getAs<EnumType>()->getOriginalDecl()->getIntegerType();
+        T2 = cast<EnumType>(T2)->getOriginalDecl()->getIntegerType();
         assert(T1->isBuiltinType() && !T2.isNull()); // Sanity check
       }
       TC = Type::Builtin;
diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index e3235d34e230e..aee7421066180 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -559,8 +559,7 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
     // Possibly diagnose casts to enum types if the target type does not
     // have a fixed size.
     if (Ctx.getLangOpts().CPlusPlus && CE->getType()->isEnumeralType()) {
-      const auto *ET = CE->getType().getCanonicalType()->castAs<EnumType>();
-      const auto *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+      const auto *ED = CE->getType()->castAsEnumDecl();
       if (!ED->isFixed()) {
         if (!this->emitCheckEnumValue(*FromT, ED, CE))
           return false;
@@ -4608,8 +4607,8 @@ UnsignedOrNone Compiler<Emitter>::allocateTemporary(const Expr *E) {
 template <class Emitter>
 const RecordType *Compiler<Emitter>::getRecordTy(QualType Ty) {
   if (const PointerType *PT = dyn_cast<PointerType>(Ty))
-    return PT->getPointeeType()->getAs<RecordType>();
-  return Ty->getAs<RecordType>();
+    return PT->getPointeeType()->getAsCanonical<RecordType>();
+  return Ty->getAsCanonical<RecordType>();
 }
 
 template <class Emitter> Record *Compiler<Emitter>::getRecord(QualType Ty) {
diff --git a/clang/lib/AST/ByteCode/Context.cpp b/clang/lib/AST/ByteCode/Context.cpp
index 36eb7607e70bf..fbbb508ed226c 100644
--- a/clang/lib/AST/ByteCode/Context.cpp
+++ b/clang/lib/AST/ByteCode/Context.cpp
@@ -364,8 +364,7 @@ OptPrimType Context::classify(QualType T) const {
     return integralTypeToPrimTypeU(BT->getNumBits());
   }
 
-  if (const auto *ET = T->getAs<EnumType>()) {
-    const auto *D = ET->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *D = T->getAsEnumDecl()) {
     if (!D->isComplete())
       return std::nullopt;
     return classify(D->getIntegerType());
diff --git a/clang/lib/AST/ByteCode/InterpBuiltin.cpp b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
index 2e28814abfddb..e245f24e614e5 100644
--- a/clang/lib/AST/ByteCode/InterpBuiltin.cpp
+++ b/clang/lib/AST/ByteCode/InterpBuiltin.cpp
@@ -3303,11 +3303,8 @@ bool InterpretOffsetOf(InterpState &S, CodePtr OpPC, const OffsetOfExpr *E,
     switch (Node.getKind()) {
     case OffsetOfNode::Field: {
       const FieldDecl *MemberDecl = Node.getField();
-      const RecordType *RT = CurrentType->getAs<RecordType>();
-      if (!RT)
-        return false;
-      const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
-      if (RD->isInvalidDecl())
+      const auto *RD = CurrentType->getAsRecordDecl();
+      if (!RD || RD->isInvalidDecl())
         return false;
       const ASTRecordLayout &RL = S.getASTContext().getASTRecordLayout(RD);
       unsigned FieldIndex = MemberDecl->getFieldIndex();
@@ -3336,23 +3333,19 @@ bool InterpretOffsetOf(InterpState &S, CodePtr OpPC, const OffsetOfExpr *E,
         return false;
 
       // Find the layout of the class whose base we are looking into.
-      const RecordType *RT = CurrentType->getAs<RecordType>();
-      if (!RT)
-        return false;
-      const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
-      if (RD->isInvalidDecl())
+      const auto *RD = CurrentType->getAsCXXRecordDecl();
+      if (!RD || RD->isInvalidDecl())
         return false;
       const ASTRecordLayout &RL = S.getASTContext().getASTRecordLayout(RD);
 
       // Find the base class itself.
       CurrentType = BaseSpec->getType();
-      const RecordType *BaseRT = CurrentType->getAs<RecordType>();
-      if (!BaseRT)
+      const auto *BaseRD = CurrentType->getAsCXXRecordDecl();
+      if (!BaseRD)
         return false;
 
       // Add the offset to the base.
-      Result += RL.getBaseClassOffset(cast<CXXRecordDecl>(
-          BaseRT->getOriginalDecl()->getDefinitionOrSelf()));
+      Result += RL.getBaseClassOffset(BaseRD);
       break;
     }
     case OffsetOfNode::Identifier:
diff --git a/clang/lib/AST/ByteCode/Pointer.cpp b/clang/lib/AST/ByteCode/Pointer.cpp
index 38856ad256a63..bd4d440e29347 100644
--- a/clang/lib/AST/ByteCode/Pointer.cpp
+++ b/clang/lib/AST/ByteCode/Pointer.cpp
@@ -700,7 +700,7 @@ std::optional<APValue> Pointer::toRValue(const Context &Ctx,
       return true;
     }
 
-    if (const auto *RT = Ty->getAs<RecordType>()) {
+    if (const auto *RT = Ty->getAsCanonical<RecordType>()) {
       const auto *Record = Ptr.getRecord();
       assert(Record && "Missing record descriptor");
 
diff --git a/clang/lib/AST/ByteCode/Program.cpp b/clang/lib/AST/ByteCode/Program.cpp
index 5d72044af969e..0892d1dc36595 100644
--- a/clang/lib/AST/ByteCode/Program.cpp
+++ b/clang/lib/AST/ByteCode/Program.cpp
@@ -332,10 +332,9 @@ Record *Program::getOrCreateRecord(const RecordDecl *RD) {
         continue;
 
       // In error cases, the base might not be a RecordType.
-      const auto *RT = Spec.getType()->getAs<RecordType>();
-      if (!RT)
+      const auto *BD = Spec.getType()->getAsCXXRecordDecl();
+      if (!BD)
         return nullptr;
-      const RecordDecl *BD = RT->getOriginalDecl()->getDefinitionOrSelf();
       const Record *BR = getOrCreateRecord(BD);
 
       const Descriptor *Desc = GetBaseDesc(BD, BR);
@@ -348,11 +347,7 @@ Record *Program::getOrCreateRecord(const RecordDecl *RD) {
     }
 
     for (const CXXBaseSpecifier &Spec : CD->vbases()) {
-      const auto *RT = Spec.getType()->getAs<RecordType>();
-      if (!RT)
-        return nullptr;
-
-      const RecordDecl *BD = RT->getOriginalDecl()->getDefinitionOrSelf();
+      const auto *BD = Spec.getType()->castAsCXXRecordDecl();
       const Record *BR = getOrCreateRecord(BD);
 
       const Descriptor *Desc = GetBaseDesc(BD, BR);
@@ -408,9 +403,8 @@ Descriptor *Program::createDescriptor(const DeclTy &D, const Type *Ty,
                                       const Expr *Init) {
 
   // Classes and structures.
-  if (const auto *RT = Ty->getAs<RecordType>()) {
-    if (const auto *Record =
-            getOrCreateRecord(RT->getOriginalDecl()->getDefinitionOrSelf()))
+  if (const auto *RD = Ty->getAsRecordDecl()) {
+    if (const auto *Record = getOrCreateRecord(RD))
       return allocateDescriptor(D, Record, MDSize, IsConst, IsTemporary,
                                 IsMutable, IsVolatile);
     return allocateDescriptor(D, MDSize);
diff --git a/clang/lib/AST/ByteCode/Record.cpp b/clang/lib/AST/ByteCode/Record.cpp
index a7934ccb4e55e..c20ec184f34f4 100644
--- a/clang/lib/AST/ByteCode/Record.cpp
+++ b/clang/lib/AST/ByteCode/Record.cpp
@@ -50,10 +50,8 @@ const Record::Base *Record::getBase(const RecordDecl *FD) const {
 }
 
 const Record::Base *Record::getBase(QualType T) const {
-  if (auto *RT = T->getAs<RecordType>()) {
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
+  if (auto *RD = T->getAsCXXRecordDecl())
     return BaseMap.lookup(RD);
-  }
   return nullptr;
 }
 
diff --git a/clang/lib/AST/CXXInheritance.cpp b/clang/lib/AST/CXXInheritance.cpp
index e4b77edc063dc..b60d445ad3414 100644
--- a/clang/lib/AST/CXXInheritance.cpp
+++ b/clang/lib/AST/CXXInheritance.cpp
@@ -128,17 +128,11 @@ bool CXXRecordDecl::forallBases(ForallBasesCallback BaseMatches) const {
   const CXXRecordDecl *Record = this;
   while (true) {
     for (const auto &I : Record->bases()) {
-      const RecordType *Ty = I.getType()->getAs<RecordType>();
-      if (!Ty)
+      const auto *Base = I.getType()->getAsCXXRecordDecl();
+      if (!Base || !(Base->isBeingDefined() || Base->isCompleteDefinition()))
         return false;
-
-      CXXRecordDecl *Base = cast_if_present<CXXRecordDecl>(
-          Ty->getOriginalDecl()->getDefinition());
-      if (!Base ||
-          (Base->isDependentContext() &&
-           !Base->isCurrentInstantiation(Record))) {
+      if (Base->isDependentContext() && !Base->isCurrentInstantiation(Record))
         return false;
-      }
 
       Queue.push_back(Base);
       if (!BaseMatches(Base))
@@ -196,7 +190,7 @@ bool CXXBasePaths::lookupInBases(ASTContext &Context,
       if (isDetectingVirtual() && DetectedVirtual == nullptr) {
         // If this is the first virtual we find, remember it. If it turns out
         // there is no base path here, we'll reset it later.
-        DetectedVirtual = BaseType->getAs<RecordType>();
+        DetectedVirtual = BaseType->getAsCanonical<RecordType>();
         SetVirtual = true;
       }
     } else {
@@ -255,9 +249,7 @@ bool CXXBasePaths::lookupInBases(ASTContext &Context,
         const TemplateSpecializationType *TST =
             BaseSpec.getType()->getAs<TemplateSpecializationType>();
         if (!TST) {
-          if (auto *RT = BaseSpec.getType()->getAs<RecordType>())
-            BaseRecord = cast<CXXRecordDecl>(RT->getOriginalDecl())
-                             ->getDefinitionOrSelf();
+          BaseRecord = BaseSpec.getType()->getAsCXXRecordDecl();
         } else {
           TemplateName TN = TST->getTemplateName();
           if (auto *TD =
@@ -271,7 +263,7 @@ bool CXXBasePaths::lookupInBases(ASTContext &Context,
             BaseRecord = nullptr;
         }
       } else {
-        BaseRecord = cast<CXXRecordDecl>(BaseSpec.getType()->getAsRecordDecl());
+        BaseRecord = BaseSpec.getType()->getAsCXXRecordDecl();
       }
       if (BaseRecord &&
           lookupInBases(Context, BaseRecord, BaseMatches, LookupInDependent)) {
@@ -335,10 +327,7 @@ bool CXXRecordDecl::lookupInBases(BaseMatchesCallback BaseMatches,
       if (!PE.Base->isVirtual())
         continue;
 
-      CXXRecordDecl *VBase = nullptr;
-      if (const RecordType *Record = PE.Base->getType()->getAs<RecordType>())
-        VBase = cast<CXXRecordDecl>(Record->getOriginalDecl())
-                    ->getDefinitionOrSelf();
+      auto *VBase = PE.Base->getType()->getAsCXXRecordDecl();
       if (!VBase)
         break;
 
@@ -347,11 +336,8 @@ bool CXXRecordDecl::lookupInBases(BaseMatchesCallback BaseMatches,
       // base is a subobject of any other path; if so, then the
       // declaration in this path are hidden by that patch.
       for (const CXXBasePath &HidingP : Paths) {
-        CXXRecordDecl *HidingClass = nullptr;
-        if (const RecordType *Record =
-                HidingP.back().Base->getType()->getAs<RecordType>())
-          HidingClass = cast<CXXRecordDecl>(Record->getOriginalDecl())
-                            ->getDefinitionOrSelf();
+        auto *HidingClass =
+            HidingP.back().Base->getType()->getAsCXXRecordDecl();
         if (!HidingClass)
           break;
 
@@ -407,7 +393,7 @@ bool CXXRecordDecl::hasMemberName(DeclarationName Name) const {
   CXXBasePaths Paths(false, false, false);
   return lookupInBases(
       [Name](const CXXBaseSpecifier *Specifier, CXXBasePath &Path) {
-        return findOrdinaryMember(Specifier->getType()->getAsCXXRecordDecl(),
+        return findOrdinaryMember(Specifier->getType()->castAsCXXRecordDecl(),
                                   Path, Name);
       },
       Paths);
@@ -470,9 +456,7 @@ void FinalOverriderCollector::Collect(const CXXRecordDecl *RD,
       = ++SubobjectCount[cast<CXXRecordDecl>(RD->getCanonicalDecl())];
 
   for (const auto &Base : RD->bases()) {
-    if (const RecordType *RT = Base.getType()->getAs<RecordType>()) {
-      const CXXRecordDecl *BaseDecl =
-          cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
+    if (const auto *BaseDecl = Base.getType()->getAsCXXRecordDecl()) {
       if (!BaseDecl->isPolymorphic())
         continue;
 
diff --git a/clang/lib/AST/Decl.cpp b/clang/lib/AST/Decl.cpp
index 4507f415ce606..343673069e15e 100644
--- a/clang/lib/AST/Decl.cpp
+++ b/clang/lib/AST/Decl.cpp
@@ -2861,9 +2861,8 @@ VarDecl::needsDestruction(const ASTContext &Ctx) const {
 
 bool VarDecl::hasFlexibleArrayInit(const ASTContext &Ctx) const {
   assert(hasInit() && "Expect initializer to check for flexible array init");
-  auto *Ty = getType()->getAs<RecordType>();
-  if (!Ty ||
-      !Ty->getOriginalDecl()->getDefinitionOrSelf()->hasFlexibleArrayMember())
+  auto *D = getType()->getAsRecordDecl();
+  if (!D || !D->hasFlexibleArrayMember())
     return false;
   auto *List = dyn_cast<InitListExpr>(getInit()->IgnoreParens());
   if (!List)
@@ -2877,11 +2876,8 @@ bool VarDecl::hasFlexibleArrayInit(const ASTContext &Ctx) const {
 
 CharUnits VarDecl::getFlexibleArrayInitChars(const ASTContext &Ctx) const {
   assert(hasInit() && "Expect initializer to check for flexible array init");
-  auto *Ty = getType()->getAs<RecordType>();
-  if (!Ty)
-    return CharUnits::Zero();
-  const RecordDecl *RD = Ty->getOriginalDecl()->getDefinitionOrSelf();
-  if (!Ty || !RD->hasFlexibleArrayMember())
+  auto *RD = getType()->getAsRecordDecl();
+  if (!RD || !RD->hasFlexibleArrayMember())
     return CharUnits::Zero();
   auto *List = dyn_cast<InitListExpr>(getInit()->IgnoreParens());
   if (!List || List->getNumInits() == 0)
@@ -2992,7 +2988,7 @@ bool ParmVarDecl::isDestroyedInCallee() const {
 
   // FIXME: isParamDestroyedInCallee() should probably imply
   // isDestructedType()
-  const auto *RT = getType()->getAs<RecordType>();
+  const auto *RT = getType()->getAsCanonical<RecordType>();
   if (RT &&
       RT->getOriginalDecl()
           ->getDefinitionOrSelf()
@@ -3507,7 +3503,7 @@ bool FunctionDecl::isUsableAsGlobalAllocationFunctionInConstantEvaluation(
     while (const auto *TD = T->getAs<TypedefType>())
       T = TD->getDecl()->getUnderlyingType();
     const IdentifierInfo *II =
-        T->castAs<EnumType>()->getOriginalDecl()->getIdentifier();
+        T->castAsCanonical<EnumType>()->getOriginalDecl()->getIdentifier();
     if (II && II->isStr("__hot_cold_t"))
       Consume();
   }
@@ -4657,7 +4653,7 @@ bool FieldDecl::isAnonymousStructOrUnion() const {
   if (!isImplicit() || getDeclName())
     return false;
 
-  if (const auto *Record = getType()->getAs<RecordType>())
+  if (const auto *Record = getType()->getAsCanonical<RecordType>())
     return Record->getOriginalDecl()->isAnonymousStructOrUnion();
 
   return false;
@@ -4715,7 +4711,7 @@ bool FieldDecl::isZeroSize(const ASTContext &Ctx) const {
     return false;
 
   //     -- is not of class type, or
-  const auto *RT = getType()->getAs<RecordType>();
+  const auto *RT = getType()->getAsCanonical<RecordType>();
   if (!RT)
     return false;
   const RecordDecl *RD = RT->getOriginalDecl()->getDefinition();
@@ -4738,7 +4734,7 @@ bool FieldDecl::isZeroSize(const ASTContext &Ctx) const {
   // MS ABI: has nonzero size if it is a class type with class type fields,
   // whether or not they have nonzero size
   return !llvm::any_of(CXXRD->fields(), [](const FieldDecl *Field) {
-    return Field->getType()->getAs<RecordType>();
+    return Field->getType()->isRecordType();
   });
 }
 
@@ -5142,7 +5138,7 @@ bool RecordDecl::isOrContainsUnion() const {
 
   if (const RecordDecl *Def = getDefinition()) {
     for (const FieldDecl *FD : Def->fields()) {
-      const RecordType *RT = FD->getType()->getAs<RecordType>();
+      const RecordType *RT = FD->getType()->getAsCanonical<RecordType>();
       if (RT && RT->getOriginalDecl()->isOrContainsUnion())
         return true;
     }
@@ -5274,10 +5270,8 @@ const FieldDecl *RecordDecl::findFirstNamedDataMember() const {
     if (I->getIdentifier())
       return I;
 
-    if (const auto *RT = I->getType()->getAs<RecordType>())
-      if (const FieldDecl *NamedDataMember = RT->getOriginalDecl()
-                                                 ->getDefinitionOrSelf()
-                                                 ->findFirstNamedDataMember())
+    if (const auto *RD = I->getType()->getAsRecordDecl())
+      if (const FieldDecl *NamedDataMember = RD->findFirstNamedDataMember())
         return NamedDataMember;
   }
 
diff --git a/clang/lib/AST/DeclCXX.cpp b/clang/lib/AST/DeclCXX.cpp
index 50b1a1d000090..97fdbe3c42bc8 100644
--- a/clang/lib/AST/DeclCXX.cpp
+++ b/clang/lib/AST/DeclCXX.cpp
@@ -216,9 +216,7 @@ CXXRecordDecl::setBases(CXXBaseSpecifier const * const *Bases,
     // Skip dependent types; we can't do any checking on them now.
     if (BaseType->isDependentType())
       continue;
-    auto *BaseClassDecl =
-        cast<CXXRecordDecl>(BaseType->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
+    auto *BaseClassDecl = BaseType->castAsCXXRecordDecl();
 
     // C++2a [class]p7:
     //   A standard-layout class is a class that:
@@ -1207,9 +1205,8 @@ void CXXRecordDecl::addedMember(Decl *D) {
     // those because they are always unnamed.
     bool IsZeroSize = Field->isZeroSize(Context);
 
-    if (const auto *RecordTy = T->getAs<RecordType>()) {
-      auto *FieldRec = cast<CXXRecordDecl>(RecordTy->getOriginalDecl());
-      if (FieldRec->getDefinition()) {
+    if (auto *FieldRec = T->getAsCXXRecordDecl()) {
+      if (FieldRec->isBeingDefined() || FieldRec->isCompleteDefinition()) {
         addedClassSubobject(FieldRec);
 
         // We may need to perform overload resolution to determine whether a
@@ -1908,15 +1905,14 @@ static void CollectVisibleConversions(
 
   // Collect information recursively from any base classes.
   for (const auto &I : Record->bases()) {
-    const auto *RT = I.getType()->getAs<RecordType>();
-    if (!RT) continue;
+    const auto *Base = I.getType()->getAsCXXRecordDecl();
+    if (!Base)
+      continue;
 
     AccessSpecifier BaseAccess
       = CXXRecordDecl::MergeAccess(Access, I.getAccessSpecifier());
     bool BaseInVirtual = InVirtual || I.isVirtual();
 
-    auto *Base =
-        cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
     CollectVisibleConversions(Context, Base, BaseInVirtual, BaseAccess,
                               *HiddenTypes, Output, VOutput, HiddenVBaseCs);
   }
@@ -1951,14 +1947,13 @@ static void CollectVisibleConversions(ASTContext &Context,
 
   // Recursively collect conversions from base classes.
   for (const auto &I : Record->bases()) {
-    const auto *RT = I.getType()->getAs<RecordType>();
-    if (!RT) continue;
+    const auto *Base = I.getType()->getAsCXXRecordDecl();
+    if (!Base)
+      continue;
 
-    CollectVisibleConversions(
-        Context,
-        cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf(),
-        I.isVirtual(), I.getAccessSpecifier(), HiddenTypes, Output, VBaseCs,
-        HiddenVBaseCs);
+    CollectVisibleConversions(Context, Base, I.isVirtual(),
+                              I.getAccessSpecifier(), HiddenTypes, Output,
+                              VBaseCs, HiddenVBaseCs);
   }
 
   // Add any unhidden conversions provided by virtual bases.
@@ -2312,7 +2307,7 @@ bool CXXRecordDecl::mayBeAbstract() const {
 
   for (const auto &B : bases()) {
     const auto *BaseDecl = cast<CXXRecordDecl>(
-        B.getType()->castAs<RecordType>()->getOriginalDecl());
+        B.getType()->castAsCanonical<RecordType>()->getOriginalDecl());
     if (BaseDecl->isAbstract())
       return true;
   }
@@ -2472,11 +2467,9 @@ CXXMethodDecl::getCorrespondingMethodInClass(const CXXRecordDecl *RD,
   };
 
   for (const auto &I : RD->bases()) {
-    const RecordType *RT = I.getType()->getAs<RecordType>();
-    if (!RT)
+    const auto *Base = I.getType()->getAsCXXRecordDecl();
+    if (!Base)
       continue;
-    const auto *Base =
-        cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
     if (CXXMethodDecl *D = this->getCorrespondingMethodInClass(Base))
       AddFinalOverrider(D);
   }
@@ -3437,13 +3430,12 @@ SourceRange UsingDecl::getSourceRange() const {
 void UsingEnumDecl::anchor() {}
 
 UsingEnumDecl *UsingEnumDecl::Create(ASTContext &C, DeclContext *DC,
-                                     SourceLocation UL,
-                                     SourceLocation EL,
+                                     SourceLocation UL, SourceLocation EL,
                                      SourceLocation NL,
                                      TypeSourceInfo *EnumType) {
-  assert(isa<EnumDecl>(EnumType->getType()->getAsTagDecl()));
   return new (C, DC)
-      UsingEnumDecl(DC, EnumType->getType()->getAsTagDecl()->getDeclName(), UL, EL, NL, EnumType);
+      UsingEnumDecl(DC, EnumType->getType()->castAsEnumDecl()->getDeclName(),
+                    UL, EL, NL, EnumType);
 }
 
 UsingEnumDecl *UsingEnumDecl::CreateDeserialized(ASTContext &C,
diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp
index b91b4670c63a3..3162857aac5d0 100644
--- a/clang/lib/AST/DeclTemplate.cpp
+++ b/clang/lib/AST/DeclTemplate.cpp
@@ -730,15 +730,15 @@ void TemplateTypeParmDecl::setDefaultArgument(
 }
 
 unsigned TemplateTypeParmDecl::getDepth() const {
-  return getTypeForDecl()->castAs<TemplateTypeParmType>()->getDepth();
+  return cast<TemplateTypeParmType>(getTypeForDecl())->getDepth();
 }
 
 unsigned TemplateTypeParmDecl::getIndex() const {
-  return getTypeForDecl()->castAs<TemplateTypeParmType>()->getIndex();
+  return cast<TemplateTypeParmType>(getTypeForDecl())->getIndex();
 }
 
 bool TemplateTypeParmDecl::isParameterPack() const {
-  return getTypeForDecl()->castAs<TemplateTypeParmType>()->isParameterPack();
+  return cast<TemplateTypeParmType>(getTypeForDecl())->isParameterPack();
 }
 
 void TemplateTypeParmDecl::setTypeConstraint(
diff --git a/clang/lib/AST/DeclarationName.cpp b/clang/lib/AST/DeclarationName.cpp
index 6c7b995d57567..74f2eec69461a 100644
--- a/clang/lib/AST/DeclarationName.cpp
+++ b/clang/lib/AST/DeclarationName.cpp
@@ -114,12 +114,12 @@ static void printCXXConstructorDestructorName(QualType ClassType,
   // We know we're printing C++ here. Ensure we print types properly.
   Policy.adjustForCPlusPlus();
 
-  if (const RecordType *ClassRec = ClassType->getAs<RecordType>()) {
+  if (const RecordType *ClassRec = ClassType->getAsCanonical<RecordType>()) {
     ClassRec->getOriginalDecl()->printName(OS, Policy);
     return;
   }
   if (Policy.SuppressTemplateArgsInCXXConstructors) {
-    if (auto *InjTy = ClassType->getAs<InjectedClassNameType>()) {
+    if (auto *InjTy = ClassType->getAsCanonical<InjectedClassNameType>()) {
       InjTy->getOriginalDecl()->printName(OS, Policy);
       return;
     }
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index 340de6d4be934..ba589ac3c0dfe 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -74,8 +74,7 @@ const CXXRecordDecl *Expr::getBestDynamicClassType() const {
   if (DerivedType->isDependentType())
     return nullptr;
 
-  const RecordType *Ty = DerivedType->castAs<RecordType>();
-  return cast<CXXRecordDecl>(Ty->getOriginalDecl())->getDefinitionOrSelf();
+  return DerivedType->castAsCXXRecordDecl();
 }
 
 const Expr *Expr::skipRValueSubobjectAdjustments(
@@ -90,10 +89,7 @@ const Expr *Expr::skipRValueSubobjectAdjustments(
            CE->getCastKind() == CK_UncheckedDerivedToBase) &&
           E->getType()->isRecordType()) {
         E = CE->getSubExpr();
-        const auto *Derived =
-            cast<CXXRecordDecl>(
-                E->getType()->castAs<RecordType>()->getOriginalDecl())
-                ->getDefinitionOrSelf();
+        const auto *Derived = E->getType()->castAsCXXRecordDecl();
         Adjustments.push_back(SubobjectAdjustment(CE, Derived));
         continue;
       }
@@ -2032,9 +2028,7 @@ CXXBaseSpecifier **CastExpr::path_buffer() {
 
 const FieldDecl *CastExpr::getTargetFieldForToUnionCast(QualType unionType,
                                                         QualType opType) {
-  auto RD =
-      unionType->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
-  return getTargetFieldForToUnionCast(RD, opType);
+  return getTargetFieldForToUnionCast(unionType->castAsRecordDecl(), opType);
 }
 
 const FieldDecl *CastExpr::getTargetFieldForToUnionCast(const RecordDecl *RD,
@@ -3396,10 +3390,7 @@ bool Expr::isConstantInitializer(ASTContext &Ctx, bool IsForRef,
 
     if (ILE->getType()->isRecordType()) {
       unsigned ElementNo = 0;
-      RecordDecl *RD = ILE->getType()
-                           ->castAs<RecordType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf();
+      auto *RD = ILE->getType()->castAsRecordDecl();
 
       // In C++17, bases were added to the list of members used by aggregate
       // initialization.
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 9b934753bcc3c..5803931c4a424 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -2623,8 +2623,7 @@ static bool CheckEvaluationResult(CheckEvaluationResultKind CERK,
         Value.getUnionValue(), Kind, Value.getUnionField(), CheckedTemps);
   }
   if (Value.isStruct()) {
-    RecordDecl *RD =
-        Type->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    auto *RD = Type->castAsRecordDecl();
     if (const CXXRecordDecl *CD = dyn_cast<CXXRecordDecl>(RD)) {
       unsigned BaseIndex = 0;
       for (const CXXBaseSpecifier &BS : CD->bases()) {
@@ -4110,7 +4109,8 @@ findSubobject(EvalInfo &Info, const Expr *E, const CompleteObject &Obj,
       }
 
       // Next subobject is a class, struct or union field.
-      RecordDecl *RD = ObjType->castAs<RecordType>()->getOriginalDecl();
+      RecordDecl *RD =
+          ObjType->castAsCanonical<RecordType>()->getOriginalDecl();
       if (RD->isUnion()) {
         const FieldDecl *UnionField = O->getUnionField();
         if (!UnionField ||
@@ -8591,10 +8591,9 @@ class ExprEvaluatorBase
     const FieldDecl *FD = dyn_cast<FieldDecl>(E->getMemberDecl());
     if (!FD) return Error(E);
     assert(!FD->getType()->isReferenceType() && "prvalue reference?");
-    assert(
-        BaseTy->castAs<RecordType>()->getOriginalDecl()->getCanonicalDecl() ==
-            FD->getParent()->getCanonicalDecl() &&
-        "record / field mismatch");
+    assert(BaseTy->castAsCanonical<RecordType>()->getOriginalDecl() ==
+               FD->getParent()->getCanonicalDecl() &&
+           "record / field mismatch");
 
     // Note: there is no lvalue base here. But this case should only ever
     // happen in C or in C++98, where we cannot be evaluating a constexpr
@@ -8821,10 +8820,9 @@ class LValueExprEvaluatorBase
 
     const ValueDecl *MD = E->getMemberDecl();
     if (const FieldDecl *FD = dyn_cast<FieldDecl>(E->getMemberDecl())) {
-      assert(
-          BaseTy->castAs<RecordType>()->getOriginalDecl()->getCanonicalDecl() ==
-              FD->getParent()->getCanonicalDecl() &&
-          "record / field mismatch");
+      assert(BaseTy->castAsCanonical<RecordType>()->getOriginalDecl() ==
+                 FD->getParent()->getCanonicalDecl() &&
+             "record / field mismatch");
       (void)BaseTy;
       if (!HandleLValueMember(this->Info, E, Result, FD))
         return false;
@@ -10824,8 +10822,7 @@ static bool HandleClassZeroInitialization(EvalInfo &Info, const Expr *E,
 }
 
 bool RecordExprEvaluator::ZeroInitialization(const Expr *E, QualType T) {
-  const RecordDecl *RD =
-      T->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *RD = T->castAsRecordDecl();
   if (RD->isInvalidDecl()) return false;
   if (RD->isUnion()) {
     // C++11 [dcl.init]p5: If T is a (possibly cv-qualified) union type, the
@@ -10894,10 +10891,7 @@ bool RecordExprEvaluator::VisitInitListExpr(const InitListExpr *E) {
 
 bool RecordExprEvaluator::VisitCXXParenListOrInitListExpr(
     const Expr *ExprToVisit, ArrayRef<Expr *> Args) {
-  const RecordDecl *RD = ExprToVisit->getType()
-                             ->castAs<RecordType>()
-                             ->getOriginalDecl()
-                             ->getDefinitionOrSelf();
+  const auto *RD = ExprToVisit->getType()->castAsRecordDecl();
   if (RD->isInvalidDecl()) return false;
   const ASTRecordLayout &Layout = Info.Ctx.getASTRecordLayout(RD);
   auto *CXXRD = dyn_cast<CXXRecordDecl>(RD);
@@ -11125,10 +11119,7 @@ bool RecordExprEvaluator::VisitCXXStdInitializerListExpr(
   Result = APValue(APValue::UninitStruct(), 0, 2);
   Array.moveInto(Result.getStructField(0));
 
-  RecordDecl *Record = E->getType()
-                           ->castAs<RecordType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf();
+  auto *Record = E->getType()->castAsRecordDecl();
   RecordDecl::field_iterator Field = Record->field_begin();
   assert(Field != Record->field_end() &&
          Info.Ctx.hasSameType(Field->getType()->getPointeeType(),
@@ -13172,10 +13163,7 @@ static bool convertUnsignedAPIntToCharUnits(const llvm::APInt &Int,
 static void addFlexibleArrayMemberInitSize(EvalInfo &Info, const QualType &T,
                                            const LValue &LV, CharUnits &Size) {
   if (!T.isNull() && T->isStructureType() &&
-      T->getAsStructureType()
-          ->getOriginalDecl()
-          ->getDefinitionOrSelf()
-          ->hasFlexibleArrayMember())
+      T->castAsRecordDecl()->hasFlexibleArrayMember())
     if (const auto *V = LV.getLValueBase().dyn_cast<const ValueDecl *>())
       if (const auto *VD = dyn_cast<VarDecl>(V))
         if (VD->hasInit())
@@ -15459,10 +15447,9 @@ bool IntExprEvaluator::VisitOffsetOfExpr(const OffsetOfExpr *OOE) {
 
     case OffsetOfNode::Field: {
       FieldDecl *MemberDecl = ON.getField();
-      const RecordType *RT = CurrentType->getAs<RecordType>();
-      if (!RT)
+      const auto *RD = CurrentType->getAsRecordDecl();
+      if (!RD)
         return Error(OOE);
-      RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
       if (RD->isInvalidDecl()) return false;
       const ASTRecordLayout &RL = Info.Ctx.getASTRecordLayout(RD);
       unsigned i = MemberDecl->getFieldIndex();
@@ -15481,22 +15468,20 @@ bool IntExprEvaluator::VisitOffsetOfExpr(const OffsetOfExpr *OOE) {
         return Error(OOE);
 
       // Find the layout of the class whose base we are looking into.
-      const RecordType *RT = CurrentType->getAs<RecordType>();
-      if (!RT)
+      const auto *RD = CurrentType->getAsCXXRecordDecl();
+      if (!RD)
         return Error(OOE);
-      RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
       if (RD->isInvalidDecl()) return false;
       const ASTRecordLayout &RL = Info.Ctx.getASTRecordLayout(RD);
 
       // Find the base class itself.
       CurrentType = BaseSpec->getType();
-      const RecordType *BaseRT = CurrentType->getAs<RecordType>();
-      if (!BaseRT)
+      const auto *BaseRD = CurrentType->getAsCXXRecordDecl();
+      if (!BaseRD)
         return Error(OOE);
 
       // Add the offset to the base.
-      Result += RL.getBaseClassOffset(cast<CXXRecordDecl>(
-          BaseRT->getOriginalDecl()->getDefinitionOrSelf()));
+      Result += RL.getBaseClassOffset(BaseRD);
       break;
     }
     }
@@ -15674,8 +15659,7 @@ bool IntExprEvaluator::VisitCastExpr(const CastExpr *E) {
     }
 
     if (Info.Ctx.getLangOpts().CPlusPlus && DestType->isEnumeralType()) {
-      const EnumType *ET = dyn_cast<EnumType>(DestType.getCanonicalType());
-      const EnumDecl *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+      const auto *ED = DestType->getAsEnumDecl();
       // Check that the value is within the range of the enumeration values.
       //
       // This corressponds to [expr.static.cast]p10 which says:
diff --git a/clang/lib/AST/FormatString.cpp b/clang/lib/AST/FormatString.cpp
index 502a3e6b145e3..62a02dd87a150 100644
--- a/clang/lib/AST/FormatString.cpp
+++ b/clang/lib/AST/FormatString.cpp
@@ -413,14 +413,13 @@ ArgType::matchesType(ASTContext &C, QualType argTy) const {
       return Match;
 
     case AnyCharTy: {
-      if (const auto *ETy = argTy->getAs<EnumType>()) {
+      if (const auto *ED = argTy->getAsEnumDecl()) {
         // If the enum is incomplete we know nothing about the underlying type.
         // Assume that it's 'int'. Do not use the underlying type for a scoped
         // enumeration.
-        const EnumDecl *ED = ETy->getOriginalDecl()->getDefinitionOrSelf();
         if (!ED->isComplete())
           return NoMatch;
-        if (ETy->isUnscopedEnumerationType())
+        if (!ED->isScoped())
           argTy = ED->getIntegerType();
       }
 
@@ -463,14 +462,13 @@ ArgType::matchesType(ASTContext &C, QualType argTy) const {
         return matchesSizeTPtrdiffT(C, argTy, T);
       }
 
-      if (const EnumType *ETy = argTy->getAs<EnumType>()) {
+      if (const auto *ED = argTy->getAsEnumDecl()) {
         // If the enum is incomplete we know nothing about the underlying type.
         // Assume that it's 'int'. Do not use the underlying type for a scoped
         // enumeration as that needs an exact match.
-        const EnumDecl *ED = ETy->getOriginalDecl()->getDefinitionOrSelf();
         if (!ED->isComplete())
           argTy = C.IntTy;
-        else if (ETy->isUnscopedEnumerationType())
+        else if (!ED->isScoped())
           argTy = ED->getIntegerType();
       }
 
diff --git a/clang/lib/AST/InheritViz.cpp b/clang/lib/AST/InheritViz.cpp
index c03492c64b161..3c4a5a8e2c4a6 100644
--- a/clang/lib/AST/InheritViz.cpp
+++ b/clang/lib/AST/InheritViz.cpp
@@ -89,8 +89,8 @@ void InheritanceHierarchyWriter::WriteNode(QualType Type, bool FromVirtual) {
   Out << " \"];\n";
 
   // Display the base classes.
-  const auto *Decl = static_cast<const CXXRecordDecl *>(
-      Type->castAs<RecordType>()->getOriginalDecl());
+  const auto *Decl = cast<CXXRecordDecl>(
+      Type->castAsCanonical<RecordType>()->getOriginalDecl());
   for (const auto &Base : Decl->bases()) {
     QualType CanonBaseType = Context.getCanonicalType(Base.getType());
 
diff --git a/clang/lib/AST/ItaniumCXXABI.cpp b/clang/lib/AST/ItaniumCXXABI.cpp
index 43a8bcd9443ff..adef1584fd9b6 100644
--- a/clang/lib/AST/ItaniumCXXABI.cpp
+++ b/clang/lib/AST/ItaniumCXXABI.cpp
@@ -42,8 +42,7 @@ namespace {
 ///
 /// Returns the name of anonymous union VarDecl or nullptr if it is not found.
 static const IdentifierInfo *findAnonymousUnionVarDeclName(const VarDecl& VD) {
-  const auto *RT = VD.getType()->castAs<RecordType>();
-  const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *RD = VD.getType()->castAsRecordDecl();
   assert(RD->isUnion() && "RecordType is expected to be a union.");
   if (const FieldDecl *FD = RD->findFirstNamedDataMember()) {
     return FD->getIdentifier();
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 112678fb2714a..ffadfce67d631 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -1580,10 +1580,7 @@ void CXXNameMangler::mangleUnqualifiedName(
 
     if (const VarDecl *VD = dyn_cast<VarDecl>(ND)) {
       // We must have an anonymous union or struct declaration.
-      const RecordDecl *RD = VD->getType()
-                                 ->castAs<RecordType>()
-                                 ->getOriginalDecl()
-                                 ->getDefinitionOrSelf();
+      const auto *RD = VD->getType()->castAsRecordDecl();
 
       // Itanium C++ ABI 5.1.2:
       //
@@ -4729,7 +4726,7 @@ void CXXNameMangler::mangleIntegerLiteral(QualType T,
 
 void CXXNameMangler::mangleMemberExprBase(const Expr *Base, bool IsArrow) {
   // Ignore member expressions involving anonymous unions.
-  while (const auto *RT = Base->getType()->getAs<RecordType>()) {
+  while (const auto *RT = Base->getType()->getAsCanonical<RecordType>()) {
     if (!RT->getOriginalDecl()->isAnonymousStructOrUnion())
       break;
     const auto *ME = dyn_cast<MemberExpr>(Base);
@@ -6997,8 +6994,8 @@ static bool hasMangledSubstitutionQualifiers(QualType T) {
 
 bool CXXNameMangler::mangleSubstitution(QualType T) {
   if (!hasMangledSubstitutionQualifiers(T)) {
-    if (const RecordType *RT = T->getAs<RecordType>())
-      return mangleSubstitution(RT->getOriginalDecl()->getDefinitionOrSelf());
+    if (const auto *RD = T->getAsCXXRecordDecl())
+      return mangleSubstitution(RD);
   }
 
   uintptr_t TypePtr = reinterpret_cast<uintptr_t>(T.getAsOpaquePtr());
@@ -7034,7 +7031,7 @@ bool CXXNameMangler::isSpecializedAs(QualType S, llvm::StringRef Name,
   if (S.isNull())
     return false;
 
-  const RecordType *RT = S->getAs<RecordType>();
+  const RecordType *RT = S->getAsCanonical<RecordType>();
   if (!RT)
     return false;
 
@@ -7168,8 +7165,8 @@ bool CXXNameMangler::mangleStandardSubstitution(const NamedDecl *ND) {
 
 void CXXNameMangler::addSubstitution(QualType T) {
   if (!hasMangledSubstitutionQualifiers(T)) {
-    if (const RecordType *RT = T->getAs<RecordType>()) {
-      addSubstitution(RT->getOriginalDecl()->getDefinitionOrSelf());
+    if (const auto *RD = T->getAsCXXRecordDecl()) {
+      addSubstitution(RD);
       return;
     }
   }
diff --git a/clang/lib/AST/JSONNodeDumper.cpp b/clang/lib/AST/JSONNodeDumper.cpp
index b3f12a1cce2ec..ca8e2af284c2b 100644
--- a/clang/lib/AST/JSONNodeDumper.cpp
+++ b/clang/lib/AST/JSONNodeDumper.cpp
@@ -396,7 +396,7 @@ llvm::json::Array JSONNodeDumper::createCastPath(const CastExpr *C) {
   for (auto I = C->path_begin(), E = C->path_end(); I != E; ++I) {
     const CXXBaseSpecifier *Base = *I;
     const auto *RD = cast<CXXRecordDecl>(
-        Base->getType()->castAs<RecordType>()->getOriginalDecl());
+        Base->getType()->castAsCanonical<RecordType>()->getOriginalDecl());
 
     llvm::json::Object Val{{"name", RD->getName()}};
     if (Base->isVirtual())
diff --git a/clang/lib/AST/PrintfFormatString.cpp b/clang/lib/AST/PrintfFormatString.cpp
index 687160c6116be..855550475721a 100644
--- a/clang/lib/AST/PrintfFormatString.cpp
+++ b/clang/lib/AST/PrintfFormatString.cpp
@@ -793,8 +793,8 @@ bool PrintfSpecifier::fixType(QualType QT, const LangOptions &LangOpt,
   }
 
   // If it's an enum, get its underlying type.
-  if (const EnumType *ETy = QT->getAs<EnumType>())
-    QT = ETy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = QT->getAsEnumDecl())
+    QT = ED->getIntegerType();
 
   const BuiltinType *BT = QT->getAs<BuiltinType>();
   if (!BT) {
diff --git a/clang/lib/AST/RecordLayoutBuilder.cpp b/clang/lib/AST/RecordLayoutBuilder.cpp
index f1f21f426f944..4b312c559bf2a 100644
--- a/clang/lib/AST/RecordLayoutBuilder.cpp
+++ b/clang/lib/AST/RecordLayoutBuilder.cpp
@@ -204,15 +204,13 @@ void EmptySubobjectMap::ComputeEmptySubobjectSizes() {
 
   // Check the fields.
   for (const FieldDecl *FD : Class->fields()) {
-    const RecordType *RT =
-        Context.getBaseElementType(FD->getType())->getAs<RecordType>();
-
-    // We only care about record types.
-    if (!RT)
+    // We only care about records.
+    const auto *MemberDecl =
+        Context.getBaseElementType(FD->getType())->getAsCXXRecordDecl();
+    if (!MemberDecl)
       continue;
 
     CharUnits EmptySize;
-    const CXXRecordDecl *MemberDecl = RT->getAsCXXRecordDecl();
     const ASTRecordLayout &Layout = Context.getASTRecordLayout(MemberDecl);
     if (MemberDecl->isEmpty()) {
       // If the class decl is empty, get its size.
@@ -433,11 +431,10 @@ EmptySubobjectMap::CanPlaceFieldSubobjectAtOffset(const FieldDecl *FD,
   // If we have an array type we need to look at every element.
   if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
     QualType ElemTy = Context.getBaseElementType(AT);
-    const RecordType *RT = ElemTy->getAs<RecordType>();
-    if (!RT)
+    const auto *RD = ElemTy->getAsCXXRecordDecl();
+    if (!RD)
       return true;
 
-    const CXXRecordDecl *RD = RT->getAsCXXRecordDecl();
     const ASTRecordLayout &Layout = Context.getASTRecordLayout(RD);
 
     uint64_t NumElements = Context.getConstantArrayElementCount(AT);
@@ -533,11 +530,10 @@ void EmptySubobjectMap::UpdateEmptyFieldSubobjects(
   // If we have an array type we need to update every element.
   if (const ConstantArrayType *AT = Context.getAsConstantArrayType(T)) {
     QualType ElemTy = Context.getBaseElementType(AT);
-    const RecordType *RT = ElemTy->getAs<RecordType>();
-    if (!RT)
+    const auto *RD = ElemTy->getAsCXXRecordDecl();
+    if (!RD)
       return;
 
-    const CXXRecordDecl *RD = RT->getAsCXXRecordDecl();
     const ASTRecordLayout &Layout = Context.getASTRecordLayout(RD);
 
     uint64_t NumElements = Context.getConstantArrayElementCount(AT);
@@ -2011,7 +2007,7 @@ void ItaniumRecordLayoutBuilder::LayoutField(const FieldDecl *D,
           CTy->getElementType()->castAs<BuiltinType>());
     } else if (const BuiltinType *BTy = BaseTy->getAs<BuiltinType>()) {
       performBuiltinTypeAlignmentUpgrade(BTy);
-    } else if (const RecordType *RT = BaseTy->getAs<RecordType>()) {
+    } else if (const RecordType *RT = BaseTy->getAsCanonical<RecordType>()) {
       const RecordDecl *RD = RT->getOriginalDecl();
       const ASTRecordLayout &FieldRecord = Context.getASTRecordLayout(RD);
       PreferredAlign = FieldRecord.getPreferredAlignment();
@@ -2711,8 +2707,9 @@ MicrosoftRecordLayoutBuilder::getAdjustedElementInfo(
     // alignment when it is applied to bitfields.
     Info.Alignment = std::max(Info.Alignment, FieldRequiredAlignment);
   else {
-    if (auto RT =
-            FD->getType()->getBaseElementTypeUnsafe()->getAs<RecordType>()) {
+    if (const auto *RT = FD->getType()
+                             ->getBaseElementTypeUnsafe()
+                             ->getAsCanonical<RecordType>()) {
       auto const &Layout = Context.getASTRecordLayout(RT->getOriginalDecl());
       EndsWithZeroSizedObject = Layout.endsWithZeroSizedObject();
       FieldRequiredAlignment = std::max(FieldRequiredAlignment,
@@ -3695,9 +3692,9 @@ static void DumpRecordLayout(raw_ostream &OS, const RecordDecl *RD,
       Offset + C.toCharUnitsFromBits(LocalFieldOffsetInBits);
 
     // Recursively dump fields of record type.
-    if (auto RT = Field->getType()->getAs<RecordType>()) {
-      DumpRecordLayout(OS, RT->getOriginalDecl()->getDefinitionOrSelf(), C,
-                       FieldOffset, IndentLevel, Field->getName().data(),
+    if (const auto *RD = Field->getType()->getAsRecordDecl()) {
+      DumpRecordLayout(OS, RD, C, FieldOffset, IndentLevel,
+                       Field->getName().data(),
                        /*PrintSizeInfo=*/false,
                        /*IncludeVirtualBases=*/true);
       continue;
diff --git a/clang/lib/AST/ScanfFormatString.cpp b/clang/lib/AST/ScanfFormatString.cpp
index 31c001d025fea..41cf71a3e042d 100644
--- a/clang/lib/AST/ScanfFormatString.cpp
+++ b/clang/lib/AST/ScanfFormatString.cpp
@@ -430,9 +430,8 @@ bool ScanfSpecifier::fixType(QualType QT, QualType RawQT,
   QualType PT = QT->getPointeeType();
 
   // If it's an enum, get its underlying type.
-  if (const EnumType *ETy = PT->getAs<EnumType>()) {
+  if (const auto *ED = PT->getAsEnumDecl()) {
     // Don't try to fix incomplete enums.
-    const EnumDecl *ED = ETy->getOriginalDecl()->getDefinitionOrSelf();
     if (!ED->isComplete())
       return false;
     PT = ED->getIntegerType();
diff --git a/clang/lib/AST/TemplateBase.cpp b/clang/lib/AST/TemplateBase.cpp
index 76050ceeb35a7..76f96fb8c5dcc 100644
--- a/clang/lib/AST/TemplateBase.cpp
+++ b/clang/lib/AST/TemplateBase.cpp
@@ -56,8 +56,8 @@ static void printIntegral(const TemplateArgument &TemplArg, raw_ostream &Out,
   const llvm::APSInt &Val = TemplArg.getAsIntegral();
 
   if (Policy.UseEnumerators) {
-    if (const EnumType *ET = T->getAs<EnumType>()) {
-      for (const EnumConstantDecl *ECD : ET->getOriginalDecl()->enumerators()) {
+    if (const auto *ED = T->getAsEnumDecl()) {
+      for (const EnumConstantDecl *ECD : ED->enumerators()) {
         // In Sema::CheckTemplateArugment, enum template arguments value are
         // extended to the size of the integer underlying the enum type.  This
         // may create a size difference between the enum value and template
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
index 085616049373e..9dca5cf088a85 100644
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -1401,7 +1401,7 @@ static void dumpBasePath(raw_ostream &OS, const CastExpr *Node) {
       OS << " -> ";
 
     const auto *RD = cast<CXXRecordDecl>(
-        Base->getType()->castAs<RecordType>()->getOriginalDecl());
+        Base->getType()->castAsCanonical<RecordType>()->getOriginalDecl());
 
     if (Base->isVirtual())
       OS << "virtual ";
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index a70bc5424009c..b34a0d556b596 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -113,10 +113,8 @@ const IdentifierInfo *QualType::getBaseTypeIdentifier() const {
     return DNT->getIdentifier();
   if (ty->isPointerOrReferenceType())
     return ty->getPointeeType().getBaseTypeIdentifier();
-  if (ty->isRecordType())
-    ND = ty->castAs<RecordType>()->getOriginalDecl();
-  else if (ty->isEnumeralType())
-    ND = ty->castAs<EnumType>()->getOriginalDecl();
+  if (const auto *TT = ty->getAs<TagType>())
+    ND = TT->getOriginalDecl();
   else if (ty->getTypeClass() == Type::Typedef)
     ND = ty->castAs<TypedefType>()->getDecl();
   else if (ty->isArrayType())
@@ -672,46 +670,42 @@ const Type *Type::getUnqualifiedDesugaredType() const {
 }
 
 bool Type::isClassType() const {
-  if (const auto *RT = getAs<RecordType>())
+  if (const auto *RT = getAsCanonical<RecordType>())
     return RT->getOriginalDecl()->isClass();
   return false;
 }
 
 bool Type::isStructureType() const {
-  if (const auto *RT = getAs<RecordType>())
+  if (const auto *RT = getAsCanonical<RecordType>())
     return RT->getOriginalDecl()->isStruct();
   return false;
 }
 
 bool Type::isStructureTypeWithFlexibleArrayMember() const {
-  const auto *RT = getAs<RecordType>();
+  const auto *RT = getAsCanonical<RecordType>();
   if (!RT)
     return false;
-  const auto *Decl = RT->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *Decl = RT->getOriginalDecl();
   if (!Decl->isStruct())
     return false;
-  return Decl->hasFlexibleArrayMember();
+  return Decl->getDefinitionOrSelf()->hasFlexibleArrayMember();
 }
 
 bool Type::isObjCBoxableRecordType() const {
-  if (const auto *RT = getAs<RecordType>())
-    return RT->getOriginalDecl()
-        ->getDefinitionOrSelf()
-        ->hasAttr<ObjCBoxableAttr>();
+  if (const auto *RD = getAsRecordDecl())
+    return RD->hasAttr<ObjCBoxableAttr>();
   return false;
 }
 
 bool Type::isInterfaceType() const {
-  if (const auto *RT = getAs<RecordType>())
+  if (const auto *RT = getAsCanonical<RecordType>())
     return RT->getOriginalDecl()->isInterface();
   return false;
 }
 
 bool Type::isStructureOrClassType() const {
-  if (const auto *RT = getAs<RecordType>()) {
-    RecordDecl *RD = RT->getOriginalDecl();
-    return RD->isStruct() || RD->isClass() || RD->isInterface();
-  }
+  if (const auto *RT = getAsCanonical<RecordType>())
+    return RT->getOriginalDecl()->isStructureOrClass();
   return false;
 }
 
@@ -722,7 +716,7 @@ bool Type::isVoidPointerType() const {
 }
 
 bool Type::isUnionType() const {
-  if (const auto *RT = getAs<RecordType>())
+  if (const auto *RT = getAsCanonical<RecordType>())
     return RT->getOriginalDecl()->isUnion();
   return false;
 }
@@ -739,7 +733,7 @@ bool Type::isComplexIntegerType() const {
 }
 
 bool Type::isScopedEnumeralType() const {
-  if (const auto *ET = getAs<EnumType>())
+  if (const auto *ET = getAsCanonical<EnumType>())
     return ET->getOriginalDecl()->isScoped();
   return false;
 }
@@ -1920,12 +1914,7 @@ const CXXRecordDecl *Type::getPointeeCXXRecordDecl() const {
     PointeeType = RT->getPointeeType();
   else
     return nullptr;
-
-  if (const auto *RT = PointeeType->getAs<RecordType>())
-    return dyn_cast<CXXRecordDecl>(
-        RT->getOriginalDecl()->getDefinitionOrSelf());
-
-  return nullptr;
+  return PointeeType->getAsCXXRecordDecl();
 }
 
 CXXRecordDecl *Type::getAsCXXRecordDecl() const {
@@ -1938,6 +1927,11 @@ CXXRecordDecl *Type::getAsCXXRecordDecl() const {
   return cast<CXXRecordDecl>(TD)->getDefinitionOrSelf();
 }
 
+CXXRecordDecl *Type::castAsCXXRecordDecl() const {
+  const auto *TT = cast<TagType>(CanonicalType);
+  return cast<CXXRecordDecl>(TT->getOriginalDecl())->getDefinitionOrSelf();
+}
+
 RecordDecl *Type::getAsRecordDecl() const {
   const auto *TT = dyn_cast<TagType>(CanonicalType);
   if (!isa_and_present<RecordType, InjectedClassNameType>(TT))
@@ -1945,12 +1939,33 @@ RecordDecl *Type::getAsRecordDecl() const {
   return cast<RecordDecl>(TT->getOriginalDecl())->getDefinitionOrSelf();
 }
 
+RecordDecl *Type::castAsRecordDecl() const {
+  const auto *TT = cast<TagType>(CanonicalType);
+  return cast<RecordDecl>(TT->getOriginalDecl())->getDefinitionOrSelf();
+}
+
+EnumDecl *Type::getAsEnumDecl() const {
+  if (const auto *TT = dyn_cast<EnumType>(CanonicalType))
+    return TT->getOriginalDecl()->getDefinitionOrSelf();
+  return nullptr;
+}
+
+EnumDecl *Type::castAsEnumDecl() const {
+  return cast<EnumType>(CanonicalType)
+      ->getOriginalDecl()
+      ->getDefinitionOrSelf();
+}
+
 TagDecl *Type::getAsTagDecl() const {
   if (const auto *TT = dyn_cast<TagType>(CanonicalType))
     return TT->getOriginalDecl()->getDefinitionOrSelf();
   return nullptr;
 }
 
+TagDecl *Type::castAsTagDecl() const {
+  return cast<TagType>(CanonicalType)->getOriginalDecl()->getDefinitionOrSelf();
+}
+
 const TemplateSpecializationType *
 Type::getAsNonAliasTemplateSpecializationType() const {
   const auto *TST = getAs<TemplateSpecializationType>();
@@ -2246,10 +2261,9 @@ bool Type::isSignedIntegerType() const {
   if (const auto *BT = dyn_cast<BuiltinType>(CanonicalType))
     return BT->isSignedInteger();
 
-  if (const EnumType *ET = dyn_cast<EnumType>(CanonicalType)) {
+  if (const auto *ED = getAsEnumDecl()) {
     // Incomplete enum types are not treated as integer types.
     // FIXME: In C++, enum types are never integer types.
-    const auto *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
     if (!ED->isComplete() || ED->isScoped())
       return false;
     return ED->getIntegerType()->isSignedIntegerType();
@@ -2267,8 +2281,7 @@ bool Type::isSignedIntegerOrEnumerationType() const {
   if (const auto *BT = dyn_cast<BuiltinType>(CanonicalType))
     return BT->isSignedInteger();
 
-  if (const auto *ET = dyn_cast<EnumType>(CanonicalType)) {
-    const auto *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *ED = getAsEnumDecl()) {
     if (!ED->isComplete())
       return false;
     return ED->getIntegerType()->isSignedIntegerType();
@@ -2296,10 +2309,9 @@ bool Type::isUnsignedIntegerType() const {
   if (const auto *BT = dyn_cast<BuiltinType>(CanonicalType))
     return BT->isUnsignedInteger();
 
-  if (const auto *ET = dyn_cast<EnumType>(CanonicalType)) {
+  if (const auto *ED = getAsEnumDecl()) {
     // Incomplete enum types are not treated as integer types.
     // FIXME: In C++, enum types are never integer types.
-    const auto *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
     if (!ED->isComplete() || ED->isScoped())
       return false;
     return ED->getIntegerType()->isUnsignedIntegerType();
@@ -2317,8 +2329,7 @@ bool Type::isUnsignedIntegerOrEnumerationType() const {
   if (const auto *BT = dyn_cast<BuiltinType>(CanonicalType))
     return BT->isUnsignedInteger();
 
-  if (const auto *ET = dyn_cast<EnumType>(CanonicalType)) {
-    const auto *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *ED = getAsEnumDecl()) {
     if (!ED->isComplete())
       return false;
     return ED->getIntegerType()->isUnsignedIntegerType();
@@ -2398,10 +2409,8 @@ bool Type::isArithmeticType() const {
 bool Type::hasBooleanRepresentation() const {
   if (const auto *VT = dyn_cast<VectorType>(CanonicalType))
     return VT->getElementType()->isBooleanType();
-  if (const auto *ET = dyn_cast<EnumType>(CanonicalType)) {
-    const auto *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *ED = getAsEnumDecl())
     return ED->isComplete() && ED->getIntegerType()->isBooleanType();
-  }
   if (const auto *IT = dyn_cast<BitIntType>(CanonicalType))
     return IT->getNumBits() == 1;
   return isBooleanType();
@@ -2432,10 +2441,7 @@ Type::ScalarTypeKind Type::getScalarTypeKind() const {
   } else if (isa<MemberPointerType>(T)) {
     return STK_MemberPointer;
   } else if (isa<EnumType>(T)) {
-    assert(cast<EnumType>(T)
-               ->getOriginalDecl()
-               ->getDefinitionOrSelf()
-               ->isComplete());
+    assert(T->getAsEnumDecl()->isComplete());
     return STK_Integral;
   } else if (const auto *CT = dyn_cast<ComplexType>(T)) {
     if (CT->getElementType()->isRealFloatingType())
@@ -2459,8 +2465,8 @@ Type::ScalarTypeKind Type::getScalarTypeKind() const {
 /// includes union types.
 bool Type::isAggregateType() const {
   if (const auto *Record = dyn_cast<RecordType>(CanonicalType)) {
-    if (const auto *ClassDecl = dyn_cast<CXXRecordDecl>(
-            Record->getOriginalDecl()->getDefinitionOrSelf()))
+    if (const auto *ClassDecl =
+            dyn_cast<CXXRecordDecl>(Record->getOriginalDecl()))
       return ClassDecl->isAggregate();
 
     return true;
@@ -2494,8 +2500,7 @@ bool Type::isIncompleteType(NamedDecl **Def) const {
     // be completed.
     return isVoidType();
   case Enum: {
-    EnumDecl *EnumD =
-        cast<EnumType>(CanonicalType)->getOriginalDecl()->getDefinitionOrSelf();
+    auto *EnumD = castAsEnumDecl();
     if (Def)
       *Def = EnumD;
     return !EnumD->isComplete();
@@ -2503,17 +2508,13 @@ bool Type::isIncompleteType(NamedDecl **Def) const {
   case Record: {
     // A tagged type (struct/union/enum/class) is incomplete if the decl is a
     // forward declaration, but not a full definition (C99 6.2.5p22).
-    RecordDecl *Rec = cast<RecordType>(CanonicalType)
-                          ->getOriginalDecl()
-                          ->getDefinitionOrSelf();
+    auto *Rec = castAsRecordDecl();
     if (Def)
       *Def = Rec;
     return !Rec->isCompleteDefinition();
   }
   case InjectedClassName: {
-    CXXRecordDecl *Rec = cast<InjectedClassNameType>(CanonicalType)
-                             ->getOriginalDecl()
-                             ->getDefinitionOrSelf();
+    auto *Rec = castAsCXXRecordDecl();
     if (!Rec->isBeingDefined())
       return false;
     if (Def)
@@ -2573,7 +2574,7 @@ bool Type::isAlwaysIncompleteType() const {
 
   // Forward declarations of structs, classes, enums, and unions could be later
   // completed in a compilation unit by providing a type definition.
-  if (getAsTagDecl())
+  if (isa<TagType>(CanonicalType))
     return false;
 
   // Other types are incompletable.
@@ -2797,7 +2798,7 @@ bool QualType::isCXX98PODType(const ASTContext &Context) const {
   case Type::Record:
     if (const auto *ClassDecl = dyn_cast<CXXRecordDecl>(
             cast<RecordType>(CanonicalType)->getOriginalDecl()))
-      return ClassDecl->getDefinitionOrSelf()->isPOD();
+      return ClassDecl->isPOD();
 
     // C struct/union is POD.
     return true;
@@ -2837,23 +2838,22 @@ bool QualType::isTrivialType(const ASTContext &Context) const {
   // As an extension, Clang treats vector types as Scalar types.
   if (CanonicalType->isScalarType() || CanonicalType->isVectorType())
     return true;
-  if (const auto *RT = CanonicalType->getAs<RecordType>()) {
-    if (const auto *ClassDecl =
-            dyn_cast<CXXRecordDecl>(RT->getOriginalDecl())) {
-      // C++20 [class]p6:
-      //   A trivial class is a class that is trivially copyable, and
-      //     has one or more eligible default constructors such that each is
-      //     trivial.
-      // FIXME: We should merge this definition of triviality into
-      // CXXRecordDecl::isTrivial. Currently it computes the wrong thing.
-      return ClassDecl->hasTrivialDefaultConstructor() &&
-             !ClassDecl->hasNonTrivialDefaultConstructor() &&
-             ClassDecl->isTriviallyCopyable();
-    }
 
-    return true;
+  if (const auto *ClassDecl = CanonicalType->getAsCXXRecordDecl()) {
+    // C++20 [class]p6:
+    //   A trivial class is a class that is trivially copyable, and
+    //     has one or more eligible default constructors such that each is
+    //     trivial.
+    // FIXME: We should merge this definition of triviality into
+    // CXXRecordDecl::isTrivial. Currently it computes the wrong thing.
+    return ClassDecl->hasTrivialDefaultConstructor() &&
+           !ClassDecl->hasNonTrivialDefaultConstructor() &&
+           ClassDecl->isTriviallyCopyable();
   }
 
+  if (isa<RecordType>(CanonicalType))
+    return true;
+
   // No other types can match.
   return false;
 }
@@ -2897,18 +2897,13 @@ static bool isTriviallyCopyableTypeImpl(const QualType &type,
   if (CanonicalType->isMFloat8Type())
     return true;
 
-  if (const auto *RT = CanonicalType->getAs<RecordType>()) {
-    if (const auto *ClassDecl =
-            dyn_cast<CXXRecordDecl>(RT->getOriginalDecl())) {
-      if (IsCopyConstructible) {
+  if (const auto *RD = CanonicalType->getAsRecordDecl()) {
+    if (const auto *ClassDecl = dyn_cast<CXXRecordDecl>(RD)) {
+      if (IsCopyConstructible)
         return ClassDecl->isTriviallyCopyConstructible();
-      } else {
-        return ClassDecl->isTriviallyCopyable();
-      }
+      return ClassDecl->isTriviallyCopyable();
     }
-    return !RT->getOriginalDecl()
-                ->getDefinitionOrSelf()
-                ->isNonTrivialToPrimitiveCopy();
+    return !RD->isNonTrivialToPrimitiveCopy();
   }
   // No other types can match.
   return false;
@@ -2996,11 +2991,9 @@ bool QualType::isWebAssemblyFuncrefType() const {
 
 QualType::PrimitiveDefaultInitializeKind
 QualType::isNonTrivialToPrimitiveDefaultInitialize() const {
-  if (const auto *RT =
-          getTypePtr()->getBaseElementTypeUnsafe()->getAs<RecordType>())
-    if (RT->getOriginalDecl()
-            ->getDefinitionOrSelf()
-            ->isNonTrivialToPrimitiveDefaultInitialize())
+  if (const auto *RD =
+          getTypePtr()->getBaseElementTypeUnsafe()->getAsRecordDecl())
+    if (RD->isNonTrivialToPrimitiveDefaultInitialize())
       return PDIK_Struct;
 
   switch (getQualifiers().getObjCLifetime()) {
@@ -3014,11 +3007,9 @@ QualType::isNonTrivialToPrimitiveDefaultInitialize() const {
 }
 
 QualType::PrimitiveCopyKind QualType::isNonTrivialToPrimitiveCopy() const {
-  if (const auto *RT =
-          getTypePtr()->getBaseElementTypeUnsafe()->getAs<RecordType>())
-    if (RT->getOriginalDecl()
-            ->getDefinitionOrSelf()
-            ->isNonTrivialToPrimitiveCopy())
+  if (const auto *RD =
+          getTypePtr()->getBaseElementTypeUnsafe()->getAsRecordDecl())
+    if (RD->isNonTrivialToPrimitiveCopy())
       return PCK_Struct;
 
   Qualifiers Qs = getQualifiers();
@@ -3075,7 +3066,7 @@ bool Type::isLiteralType(const ASTContext &Ctx) const {
   if (BaseTy->isReferenceType())
     return true;
   //    -- a class type that has all of the following properties:
-  if (const auto *RT = BaseTy->getAs<RecordType>()) {
+  if (const auto *RD = BaseTy->getAsRecordDecl()) {
     //    -- a trivial destructor,
     //    -- every constructor call and full-expression in the
     //       brace-or-equal-initializers for non-static data members (if any)
@@ -3086,8 +3077,8 @@ bool Type::isLiteralType(const ASTContext &Ctx) const {
     //    -- all non-static data members and base classes of literal types
     //
     // We resolve DR1361 by ignoring the second bullet.
-    if (const auto *ClassDecl = dyn_cast<CXXRecordDecl>(RT->getOriginalDecl()))
-      return ClassDecl->getDefinitionOrSelf()->isLiteral();
+    if (const auto *ClassDecl = dyn_cast<CXXRecordDecl>(RD))
+      return ClassDecl->isLiteral();
 
     return true;
   }
@@ -3139,10 +3130,10 @@ bool Type::isStandardLayoutType() const {
   // As an extension, Clang treats vector types as Scalar types.
   if (BaseTy->isScalarType() || BaseTy->isVectorType())
     return true;
-  if (const auto *RT = BaseTy->getAs<RecordType>()) {
-    if (const auto *ClassDecl = dyn_cast<CXXRecordDecl>(RT->getOriginalDecl()))
-      if (!ClassDecl->getDefinitionOrSelf()->isStandardLayout())
-        return false;
+  if (const auto *RD = BaseTy->getAsRecordDecl()) {
+    if (const auto *ClassDecl = dyn_cast<CXXRecordDecl>(RD);
+        ClassDecl && !ClassDecl->isStandardLayout())
+      return false;
 
     // Default to 'true' for non-C++ class types.
     // FIXME: This is a bit dubious, but plain C structs should trivially meet
@@ -3182,10 +3173,8 @@ bool QualType::isCXX11PODType(const ASTContext &Context) const {
   // As an extension, Clang treats vector types as Scalar types.
   if (BaseTy->isScalarType() || BaseTy->isVectorType())
     return true;
-  if (const auto *RT = BaseTy->getAs<RecordType>()) {
-    if (const auto *ClassDecl =
-            dyn_cast<CXXRecordDecl>(RT->getOriginalDecl())) {
-      ClassDecl = ClassDecl->getDefinitionOrSelf();
+  if (const auto *RD = BaseTy->getAsRecordDecl()) {
+    if (const auto *ClassDecl = dyn_cast<CXXRecordDecl>(RD)) {
       // C++11 [class]p10:
       //   A POD struct is a non-union class that is both a trivial class [...]
       if (!ClassDecl->isTrivial())
@@ -3224,7 +3213,7 @@ bool Type::isNothrowT() const {
 }
 
 bool Type::isAlignValT() const {
-  if (const auto *ET = getAs<EnumType>()) {
+  if (const auto *ET = getAsCanonical<EnumType>()) {
     const auto *ED = ET->getOriginalDecl();
     IdentifierInfo *II = ED->getIdentifier();
     if (II && II->isStr("align_val_t") && ED->isInStdNamespace())
@@ -3234,7 +3223,7 @@ bool Type::isAlignValT() const {
 }
 
 bool Type::isStdByteType() const {
-  if (const auto *ET = getAs<EnumType>()) {
+  if (const auto *ET = getAsCanonical<EnumType>()) {
     const auto *ED = ET->getOriginalDecl();
     IdentifierInfo *II = ED->getIdentifier();
     if (II && II->isStr("byte") && ED->isInStdNamespace())
@@ -4405,7 +4394,7 @@ bool RecordType::hasConstFields() const {
       if (FieldTy.isConstQualified())
         return true;
       FieldTy = FieldTy.getCanonicalType();
-      if (const auto *FieldRecTy = FieldTy->getAs<RecordType>()) {
+      if (const auto *FieldRecTy = FieldTy->getAsCanonical<RecordType>()) {
         if (!llvm::is_contained(RecordTypeList, FieldRecTy))
           RecordTypeList.push_back(FieldRecTy);
       }
@@ -5419,7 +5408,7 @@ bool Type::isCARCBridgableType() const {
 
 /// Check if the specified type is the CUDA device builtin surface type.
 bool Type::isCUDADeviceBuiltinSurfaceType() const {
-  if (const auto *RT = getAs<RecordType>())
+  if (const auto *RT = getAsCanonical<RecordType>())
     return RT->getOriginalDecl()
         ->getMostRecentDecl()
         ->hasAttr<CUDADeviceBuiltinSurfaceTypeAttr>();
@@ -5428,7 +5417,7 @@ bool Type::isCUDADeviceBuiltinSurfaceType() const {
 
 /// Check if the specified type is the CUDA device builtin texture type.
 bool Type::isCUDADeviceBuiltinTextureType() const {
-  if (const auto *RT = getAs<RecordType>())
+  if (const auto *RT = getAsCanonical<RecordType>())
     return RT->getOriginalDecl()
         ->getMostRecentDecl()
         ->hasAttr<CUDADeviceBuiltinTextureTypeAttr>();
@@ -5503,8 +5492,7 @@ QualType::DestructionKind QualType::isDestructedTypeImpl(QualType type) {
     return DK_objc_weak_lifetime;
   }
 
-  if (const auto *RT = type->getBaseElementTypeUnsafe()->getAs<RecordType>()) {
-    const RecordDecl *RD = RT->getOriginalDecl();
+  if (const auto *RD = type->getBaseElementTypeUnsafe()->getAsRecordDecl()) {
     if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
       /// Check if this is a C++ object with a non-trivial destructor.
       if (CXXRD->hasDefinition() && !CXXRD->hasTrivialDestructor())
@@ -5512,7 +5500,7 @@ QualType::DestructionKind QualType::isDestructedTypeImpl(QualType type) {
     } else {
       /// Check if this is a C struct that is non-trivial to destroy or an array
       /// that contains such a struct.
-      if (RD->getDefinitionOrSelf()->isNonTrivialToPrimitiveDestroy())
+      if (RD->isNonTrivialToPrimitiveDestroy())
         return DK_nontrivial_c_struct;
     }
   }
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index 52ff0d5b5771b..54ca42d2035ad 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -2382,7 +2382,7 @@ static bool isSubstitutedType(ASTContext &Ctx, QualType T, QualType Pattern,
     return true;
 
   // A type parameter matches its argument.
-  if (auto *TTPT = Pattern->getAs<TemplateTypeParmType>()) {
+  if (auto *TTPT = Pattern->getAsCanonical<TemplateTypeParmType>()) {
     if (TTPT->getDepth() == Depth && TTPT->getIndex() < Args.size() &&
         Args[TTPT->getIndex()].getKind() == TemplateArgument::Type) {
       QualType SubstArg = Ctx.getQualifiedType(
diff --git a/clang/lib/AST/VTTBuilder.cpp b/clang/lib/AST/VTTBuilder.cpp
index 85101aee97e66..89b58b557ddca 100644
--- a/clang/lib/AST/VTTBuilder.cpp
+++ b/clang/lib/AST/VTTBuilder.cpp
@@ -63,11 +63,7 @@ void VTTBuilder::LayoutSecondaryVTTs(BaseSubobject Base) {
     if (I.isVirtual())
         continue;
 
-    const auto *BaseDecl =
-        cast<CXXRecordDecl>(
-            I.getType()->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
-
+    const auto *BaseDecl = I.getType()->castAsCXXRecordDecl();
     const ASTRecordLayout &Layout = Ctx.getASTRecordLayout(RD);
     CharUnits BaseOffset = Base.getBaseOffset() +
       Layout.getBaseClassOffset(BaseDecl);
@@ -91,10 +87,7 @@ VTTBuilder::LayoutSecondaryVirtualPointers(BaseSubobject Base,
     return;
 
   for (const auto &I : RD->bases()) {
-    const auto *BaseDecl =
-        cast<CXXRecordDecl>(
-            I.getType()->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
+    const auto *BaseDecl = I.getType()->castAsCXXRecordDecl();
 
     // Itanium C++ ABI 2.6.2:
     //   Secondary virtual pointers are present for all bases with either
@@ -157,10 +150,7 @@ VTTBuilder::LayoutSecondaryVirtualPointers(BaseSubobject Base,
 void VTTBuilder::LayoutVirtualVTTs(const CXXRecordDecl *RD,
                                    VisitedVirtualBasesSetTy &VBases) {
   for (const auto &I : RD->bases()) {
-    const auto *BaseDecl =
-        cast<CXXRecordDecl>(
-            I.getType()->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
+    const auto *BaseDecl = I.getType()->castAsCXXRecordDecl();
 
     // Check if this is a virtual base.
     if (I.isVirtual()) {
diff --git a/clang/lib/Analysis/ExprMutationAnalyzer.cpp b/clang/lib/Analysis/ExprMutationAnalyzer.cpp
index 823d7543f085f..3fcd3481c2d6b 100644
--- a/clang/lib/Analysis/ExprMutationAnalyzer.cpp
+++ b/clang/lib/Analysis/ExprMutationAnalyzer.cpp
@@ -703,7 +703,8 @@ ExprMutationAnalyzer::Analyzer::findFunctionArgMutation(const Expr *Exp) {
     // definition and see whether the param is mutated inside.
     if (const auto *RefType = ParmType->getAs<RValueReferenceType>()) {
       if (!RefType->getPointeeType().getQualifiers() &&
-          RefType->getPointeeType()->getAs<TemplateTypeParmType>()) {
+          isa<TemplateTypeParmType>(
+              RefType->getPointeeType().getCanonicalType())) {
         FunctionParmMutationAnalyzer *Analyzer =
             FunctionParmMutationAnalyzer::getFunctionParmMutationAnalyzer(
                 *Func, Context, Memorized);
diff --git a/clang/lib/Analysis/ThreadSafetyCommon.cpp b/clang/lib/Analysis/ThreadSafetyCommon.cpp
index f560dd8ae1dd1..68c27eeb7e8e6 100644
--- a/clang/lib/Analysis/ThreadSafetyCommon.cpp
+++ b/clang/lib/Analysis/ThreadSafetyCommon.cpp
@@ -83,13 +83,11 @@ static std::pair<StringRef, bool> classifyCapability(QualType QT) {
   // We need to look at the declaration of the type of the value to determine
   // which it is. The type should either be a record or a typedef, or a pointer
   // or reference thereof.
-  if (const auto *RT = QT->getAs<RecordType>()) {
-    if (const auto *RD = RT->getOriginalDecl())
-      return classifyCapability(*RD->getDefinitionOrSelf());
-  } else if (const auto *TT = QT->getAs<TypedefType>()) {
-    if (const auto *TD = TT->getDecl())
-      return classifyCapability(*TD);
-  } else if (QT->isPointerOrReferenceType())
+  if (const auto *RD = QT->getAsRecordDecl())
+    return classifyCapability(*RD);
+  if (const auto *TT = QT->getAs<TypedefType>())
+    return classifyCapability(*TT->getDecl());
+  if (QT->isPointerOrReferenceType())
     return classifyCapability(QT->getPointeeType());
 
   return ClassifyCapabilityFallback;
diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
index 8a15e5f96aea2..25859885296fa 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
@@ -287,10 +287,7 @@ void CIRGenFunction::emitDelegateCallArg(CallArgList &args,
   // Deactivate the cleanup for the callee-destructed param that was pushed.
   assert(!cir::MissingFeatures::thunks());
   if (type->isRecordType() &&
-      type->castAs<RecordType>()
-          ->getOriginalDecl()
-          ->getDefinitionOrSelf()
-          ->isParamDestroyedInCallee() &&
+      type->castAsRecordDecl()->isParamDestroyedInCallee() &&
       param->needsDestruction(getContext())) {
     cgm.errorNYI(param->getSourceRange(),
                  "emitDelegateCallArg: callee-destructed param");
@@ -690,10 +687,8 @@ void CIRGenFunction::emitCallArg(CallArgList &args, const clang::Expr *e,
   // In the Microsoft C++ ABI, aggregate arguments are destructed by the callee.
   // However, we still have to push an EH-only cleanup in case we unwind before
   // we make it to the call.
-  if (argType->isRecordType() && argType->castAs<RecordType>()
-                                     ->getOriginalDecl()
-                                     ->getDefinitionOrSelf()
-                                     ->isParamDestroyedInCallee()) {
+  if (argType->isRecordType() &&
+      argType->castAsRecordDecl()->isParamDestroyedInCallee()) {
     assert(!cir::MissingFeatures::msabi());
     cgm.errorNYI(e->getSourceRange(), "emitCallArg: msabi is NYI");
   }
diff --git a/clang/lib/CIR/CodeGen/CIRGenClass.cpp b/clang/lib/CIR/CodeGen/CIRGenClass.cpp
index 3e5dc22426d8e..722027aa5631d 100644
--- a/clang/lib/CIR/CodeGen/CIRGenClass.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenClass.cpp
@@ -120,9 +120,7 @@ static void emitMemberInitializer(CIRGenFunction &cgf,
 
 static bool isInitializerOfDynamicClass(const CXXCtorInitializer *baseInit) {
   const Type *baseType = baseInit->getBaseClass();
-  const auto *baseClassDecl =
-      cast<CXXRecordDecl>(baseType->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
+  const auto *baseClassDecl = baseType->castAsCXXRecordDecl();
   return baseClassDecl->isDynamicClass();
 }
 
@@ -161,9 +159,7 @@ void CIRGenFunction::emitBaseInitializer(mlir::Location loc,
   Address thisPtr = loadCXXThisAddress();
 
   const Type *baseType = baseInit->getBaseClass();
-  const auto *baseClassDecl =
-      cast<CXXRecordDecl>(baseType->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
+  const auto *baseClassDecl = baseType->castAsCXXRecordDecl();
 
   bool isBaseVirtual = baseInit->isBaseVirtual();
 
@@ -568,9 +564,7 @@ void CIRGenFunction::emitImplicitAssignmentOperatorBody(FunctionArgList &args) {
 
 void CIRGenFunction::destroyCXXObject(CIRGenFunction &cgf, Address addr,
                                       QualType type) {
-  const RecordType *rtype = type->castAs<RecordType>();
-  const CXXRecordDecl *record =
-      cast<CXXRecordDecl>(rtype->getOriginalDecl())->getDefinitionOrSelf();
+  const auto *record = type->castAsCXXRecordDecl();
   const CXXDestructorDecl *dtor = record->getDestructor();
   // TODO(cir): Unlike traditional codegen, CIRGen should actually emit trivial
   // dtors which shall be removed on later CIR passes. However, only remove this
diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
index 2b74b27100402..1afac6dd52c2d 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp
@@ -1092,11 +1092,7 @@ LValue CIRGenFunction::emitCastLValue(const CastExpr *e) {
 
   case CK_UncheckedDerivedToBase:
   case CK_DerivedToBase: {
-    const auto *derivedClassTy =
-        e->getSubExpr()->getType()->castAs<clang::RecordType>();
-    auto *derivedClassDecl =
-        cast<CXXRecordDecl>(derivedClassTy->getOriginalDecl())
-            ->getDefinitionOrSelf();
+    auto *derivedClassDecl = e->getSubExpr()->getType()->castAsCXXRecordDecl();
 
     LValue lv = emitLValue(e->getSubExpr());
     Address thisAddr = lv.getAddress();
@@ -1265,17 +1261,11 @@ static void pushTemporaryCleanup(CIRGenFunction &cgf,
   case SD_Static:
   case SD_Thread: {
     CXXDestructorDecl *referenceTemporaryDtor = nullptr;
-    if (const clang::RecordType *rt = e->getType()
-                                          ->getBaseElementTypeUnsafe()
-                                          ->getAs<clang::RecordType>()) {
+    if (const auto *classDecl =
+            e->getType()->getBaseElementTypeUnsafe()->getAsCXXRecordDecl();
+        classDecl && !classDecl->hasTrivialDestructor())
       // Get the destructor for the reference temporary.
-      if (const auto *classDecl = dyn_cast<CXXRecordDecl>(
-              rt->getOriginalDecl()->getDefinitionOrSelf())) {
-        if (!classDecl->hasTrivialDestructor())
-          referenceTemporaryDtor =
-              classDecl->getDefinitionOrSelf()->getDestructor();
-      }
-    }
+      referenceTemporaryDtor = classDecl->getDestructor();
 
     if (!referenceTemporaryDtor)
       return;
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp
index 2a3c98c458256..bab9ac73d7e65 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp
@@ -409,10 +409,7 @@ void AggExprEmitter::visitCXXParenListOrInitListExpr(
   // the disadvantage is that the generated code is more difficult for
   // the optimizer, especially with bitfields.
   unsigned numInitElements = args.size();
-  RecordDecl *record = e->getType()
-                           ->castAs<RecordType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf();
+  auto *record = e->getType()->castAsRecordDecl();
 
   // We'll need to enter cleanup scopes in case any of the element
   // initializers throws an exception.
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
index ced3c2273e865..262d2548d5c39 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprConstant.cpp
@@ -680,9 +680,7 @@ mlir::Attribute ConstantEmitter::tryEmitPrivateForVarInit(const VarDecl &d) {
         // assignments and whatnots). Since this is for globals shouldn't
         // be a problem for the near future.
         if (cd->isTrivial() && cd->isDefaultConstructor()) {
-          const auto *cxxrd =
-              cast<CXXRecordDecl>(ty->getAs<RecordType>()->getOriginalDecl())
-                  ->getDefinitionOrSelf();
+          const auto *cxxrd = ty->castAsCXXRecordDecl();
           if (cxxrd->getNumBases() != 0) {
             // There may not be anything additional to do here, but this will
             // force us to pause and test this path when it is supported.
@@ -949,7 +947,7 @@ mlir::Value CIRGenModule::emitNullConstant(QualType t, mlir::Location loc) {
     errorNYI("CIRGenModule::emitNullConstant ConstantArrayType");
   }
 
-  if (t->getAs<RecordType>())
+  if (t->isRecordType())
     errorNYI("CIRGenModule::emitNullConstant RecordType");
 
   assert(t->isMemberDataPointerType() &&
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp
index 706c14ae962a8..deabb94b7d129 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.cpp
@@ -356,12 +356,8 @@ static bool mayDropFunctionReturn(const ASTContext &astContext,
                                   QualType returnType) {
   // We can't just discard the return value for a record type with a complex
   // destructor or a non-trivially copyable type.
-  if (const RecordType *recordType =
-          returnType.getCanonicalType()->getAs<RecordType>()) {
-    if (const auto *classDecl = dyn_cast<CXXRecordDecl>(
-            recordType->getOriginalDecl()->getDefinitionOrSelf()))
-      return classDecl->hasTrivialDestructor();
-  }
+  if (const auto *classDecl = returnType->getAsCXXRecordDecl())
+    return classDecl->hasTrivialDestructor();
   return returnType.isTriviallyCopyableType(astContext);
 }
 
@@ -829,14 +825,9 @@ std::string CIRGenFunction::getCounterAggTmpAsString() {
 void CIRGenFunction::emitNullInitialization(mlir::Location loc, Address destPtr,
                                             QualType ty) {
   // Ignore empty classes in C++.
-  if (getLangOpts().CPlusPlus) {
-    if (const RecordType *rt = ty->getAs<RecordType>()) {
-      if (cast<CXXRecordDecl>(rt->getOriginalDecl())
-              ->getDefinitionOrSelf()
-              ->isEmpty())
-        return;
-    }
-  }
+  if (getLangOpts().CPlusPlus)
+    if (const auto *rd = ty->getAsCXXRecordDecl(); rd && rd->isEmpty())
+      return;
 
   // Cast the dest ptr to the appropriate i8 pointer type.
   if (builder.isInt8Ty(destPtr.getElementType())) {
diff --git a/clang/lib/CIR/CodeGen/CIRGenModule.cpp b/clang/lib/CIR/CodeGen/CIRGenModule.cpp
index 46bca51767c02..870fd973e6e9a 100644
--- a/clang/lib/CIR/CodeGen/CIRGenModule.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenModule.cpp
@@ -1074,8 +1074,7 @@ static bool isVarDeclStrongDefinition(const ASTContext &astContext,
     if (astContext.isAlignmentRequired(varType))
       return true;
 
-    if (const auto *rt = varType->getAs<RecordType>()) {
-      const RecordDecl *rd = rt->getOriginalDecl()->getDefinitionOrSelf();
+    if (const auto *rd = varType->getAsRecordDecl()) {
       for (const FieldDecl *fd : rd->fields()) {
         if (fd->isBitField())
           continue;
@@ -2178,10 +2177,7 @@ CharUnits CIRGenModule::computeNonVirtualBaseClassOffset(
     // Get the layout.
     const ASTRecordLayout &layout = astContext.getASTRecordLayout(rd);
 
-    const auto *baseDecl =
-        cast<CXXRecordDecl>(
-            base->getType()->castAs<clang::RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
+    const auto *baseDecl = base->getType()->castAsCXXRecordDecl();
 
     // Add the offset.
     offset += layout.getBaseClassOffset(baseDecl);
diff --git a/clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp b/clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp
index 4ad15e6922ddd..5e1bc8631cac8 100644
--- a/clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenOpenACCClause.cpp
@@ -506,8 +506,9 @@ class OpenACCClauseCIREmitter final
                              const VarDecl *varRecipe, const VarDecl *temporary,
                              DeclContext *dc, QualType baseType,
                              mlir::Value mainOp) {
-    mlir::ModuleOp mod =
-        builder.getBlock()->getParent()->getParentOfType<mlir::ModuleOp>();
+    mlir::ModuleOp mod = builder.getBlock()
+                             ->getParent()
+                             ->template getParentOfType<mlir::ModuleOp>();
 
     std::string recipeName =
         getRecipeName<RecipeTy>(varRef->getSourceRange(), baseType);
diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
index aada6094d0fd5..bb24933a22ed7 100644
--- a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp
@@ -185,9 +185,8 @@ isSafeToConvert(QualType qt, CIRGenTypes &cgt,
     qt = at->getValueType();
 
   // If this is a record, check it.
-  if (const auto *rt = qt->getAs<RecordType>())
-    return isSafeToConvert(rt->getOriginalDecl()->getDefinitionOrSelf(), cgt,
-                           alreadyChecked);
+  if (const auto *rd = qt->getAsRecordDecl())
+    return isSafeToConvert(rd, cgt, alreadyChecked);
 
   // If this is an array, check the elements, which are embedded inline.
   if (const auto *at = cgt.getASTContext().getAsArrayType(qt))
@@ -247,10 +246,7 @@ mlir::Type CIRGenTypes::convertRecordDeclType(const clang::RecordDecl *rd) {
     for (const auto &base : cxxRecordDecl->bases()) {
       if (base.isVirtual())
         continue;
-      convertRecordDeclType(base.getType()
-                                ->castAs<RecordType>()
-                                ->getOriginalDecl()
-                                ->getDefinitionOrSelf());
+      convertRecordDeclType(base.getType()->castAsRecordDecl());
     }
   }
 
@@ -466,8 +462,7 @@ mlir::Type CIRGenTypes::convertType(QualType type) {
   }
 
   case Type::Enum: {
-    const EnumDecl *ed =
-        cast<EnumType>(ty)->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *ed = ty->castAsEnumDecl();
     if (auto integerType = ed->getIntegerType(); !integerType.isNull())
       return convertType(integerType);
     // Return a placeholder 'i32' type.  This can be changed later when the
@@ -571,10 +566,8 @@ bool CIRGenTypes::isZeroInitializable(clang::QualType t) {
         return true;
   }
 
-  if (const RecordType *rt = t->getAs<RecordType>()) {
-    const RecordDecl *rd = rt->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *rd = t->getAsRecordDecl())
     return isZeroInitializable(rd);
-  }
 
   if (t->getAs<MemberPointerType>()) {
     cgm.errorNYI(SourceLocation(), "isZeroInitializable for MemberPointerType",
diff --git a/clang/lib/CIR/CodeGen/TargetInfo.cpp b/clang/lib/CIR/CodeGen/TargetInfo.cpp
index 7b6259b04122d..62a8c59abe604 100644
--- a/clang/lib/CIR/CodeGen/TargetInfo.cpp
+++ b/clang/lib/CIR/CodeGen/TargetInfo.cpp
@@ -6,12 +6,10 @@ using namespace clang::CIRGen;
 
 bool clang::CIRGen::isEmptyRecordForLayout(const ASTContext &context,
                                            QualType t) {
-  const RecordType *rt = t->getAs<RecordType>();
-  if (!rt)
+  const auto *rd = t->getAsRecordDecl();
+  if (!rd)
     return false;
 
-  const RecordDecl *rd = rt->getOriginalDecl()->getDefinitionOrSelf();
-
   // If this is a C++ record, check the bases first.
   if (const CXXRecordDecl *cxxrd = dyn_cast<CXXRecordDecl>(rd)) {
     if (cxxrd->isDynamicClass())
diff --git a/clang/lib/CodeGen/ABIInfo.cpp b/clang/lib/CodeGen/ABIInfo.cpp
index 83cc6377d502c..acd678193b5a8 100644
--- a/clang/lib/CodeGen/ABIInfo.cpp
+++ b/clang/lib/CodeGen/ABIInfo.cpp
@@ -67,8 +67,7 @@ bool ABIInfo::isHomogeneousAggregate(QualType Ty, const Type *&Base,
     if (!isHomogeneousAggregate(AT->getElementType(), Base, Members))
       return false;
     Members *= NElements;
-  } else if (const RecordType *RT = Ty->getAs<RecordType>()) {
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
+  } else if (const auto *RD = Ty->getAsRecordDecl()) {
     if (RD->hasFlexibleArrayMember())
       return false;
 
diff --git a/clang/lib/CodeGen/ABIInfoImpl.cpp b/clang/lib/CodeGen/ABIInfoImpl.cpp
index 79dbe70a0c8eb..13c837a0fb680 100644
--- a/clang/lib/CodeGen/ABIInfoImpl.cpp
+++ b/clang/lib/CodeGen/ABIInfoImpl.cpp
@@ -28,8 +28,8 @@ ABIArgInfo DefaultABIInfo::classifyArgumentType(QualType Ty) const {
   }
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   ASTContext &Context = getContext();
   if (const auto *EIT = Ty->getAs<BitIntType>())
@@ -52,8 +52,8 @@ ABIArgInfo DefaultABIInfo::classifyReturnType(QualType RetTy) const {
     return getNaturalAlignIndirect(RetTy, getDataLayout().getAllocaAddrSpace());
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-    RetTy = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = RetTy->getAsEnumDecl())
+    RetTy = ED->getIntegerType();
 
   if (const auto *EIT = RetTy->getAs<BitIntType>())
     if (EIT->getNumBits() >
@@ -114,7 +114,7 @@ CGCXXABI::RecordArgABI CodeGen::getRecordArgABI(const RecordType *RT,
 }
 
 CGCXXABI::RecordArgABI CodeGen::getRecordArgABI(QualType T, CGCXXABI &CXXABI) {
-  const RecordType *RT = T->getAs<RecordType>();
+  const RecordType *RT = T->getAsCanonical<RecordType>();
   if (!RT)
     return CGCXXABI::RAA_Default;
   return getRecordArgABI(RT, CXXABI);
@@ -124,13 +124,11 @@ bool CodeGen::classifyReturnType(const CGCXXABI &CXXABI, CGFunctionInfo &FI,
                                  const ABIInfo &Info) {
   QualType Ty = FI.getReturnType();
 
-  if (const auto *RT = Ty->getAs<RecordType>()) {
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
-    if (!isa<CXXRecordDecl>(RD) && !RD->canPassInRegisters()) {
-      FI.getReturnInfo() = Info.getNaturalAlignIndirect(
-          Ty, Info.getDataLayout().getAllocaAddrSpace());
-      return true;
-    }
+  if (const auto *RD = Ty->getAsRecordDecl();
+      RD && !isa<CXXRecordDecl>(RD) && !RD->canPassInRegisters()) {
+    FI.getReturnInfo() = Info.getNaturalAlignIndirect(
+        Ty, Info.getDataLayout().getAllocaAddrSpace());
+    return true;
   }
 
   return CXXABI.classifyReturnType(FI);
@@ -262,7 +260,7 @@ bool CodeGen::isEmptyField(ASTContext &Context, const FieldDecl *FD,
       WasArray = true;
     }
 
-  const RecordType *RT = FT->getAs<RecordType>();
+  const RecordType *RT = FT->getAsCanonical<RecordType>();
   if (!RT)
     return false;
 
@@ -285,10 +283,9 @@ bool CodeGen::isEmptyField(ASTContext &Context, const FieldDecl *FD,
 
 bool CodeGen::isEmptyRecord(ASTContext &Context, QualType T, bool AllowArrays,
                             bool AsIfNoUniqueAddr) {
-  const RecordType *RT = T->getAs<RecordType>();
-  if (!RT)
+  const auto *RD = T->getAsRecordDecl();
+  if (!RD)
     return false;
-  const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
   if (RD->hasFlexibleArrayMember())
     return false;
 
@@ -316,12 +313,10 @@ bool CodeGen::isEmptyFieldForLayout(const ASTContext &Context,
 }
 
 bool CodeGen::isEmptyRecordForLayout(const ASTContext &Context, QualType T) {
-  const RecordType *RT = T->getAs<RecordType>();
-  if (!RT)
+  const auto *RD = T->getAsRecordDecl();
+  if (!RD)
     return false;
 
-  const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
-
   // If this is a C++ record, check the bases first.
   if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
     if (CXXRD->isDynamicClass())
@@ -340,11 +335,10 @@ bool CodeGen::isEmptyRecordForLayout(const ASTContext &Context, QualType T) {
 }
 
 const Type *CodeGen::isSingleElementStruct(QualType T, ASTContext &Context) {
-  const RecordType *RT = T->getAs<RecordType>();
-  if (!RT)
+  const auto *RD = T->getAsRecordDecl();
+  if (!RD)
     return nullptr;
 
-  const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
   if (RD->hasFlexibleArrayMember())
     return nullptr;
 
@@ -460,10 +454,9 @@ bool CodeGen::isSIMDVectorType(ASTContext &Context, QualType Ty) {
 }
 
 bool CodeGen::isRecordWithSIMDVectorType(ASTContext &Context, QualType Ty) {
-  const RecordType *RT = Ty->getAs<RecordType>();
-  if (!RT)
+  const auto *RD = Ty->getAsRecordDecl();
+  if (!RD)
     return false;
-  const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
 
   // If this is a C++ record, check the bases first.
   if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD))
diff --git a/clang/lib/CodeGen/CGBlocks.cpp b/clang/lib/CodeGen/CGBlocks.cpp
index 74d92ef038eb9..597127abc9120 100644
--- a/clang/lib/CodeGen/CGBlocks.cpp
+++ b/clang/lib/CodeGen/CGBlocks.cpp
@@ -420,14 +420,11 @@ static void addBlockLayout(CharUnits align, CharUnits size,
 
 /// Determines if the given type is safe for constant capture in C++.
 static bool isSafeForCXXConstantCapture(QualType type) {
-  const RecordType *recordType =
-    type->getBaseElementTypeUnsafe()->getAs<RecordType>();
+  const auto *record = type->getBaseElementTypeUnsafe()->getAsCXXRecordDecl();
 
   // Only records can be unsafe.
-  if (!recordType) return true;
-
-  const auto *record =
-      cast<CXXRecordDecl>(recordType->getOriginalDecl())->getDefinitionOrSelf();
+  if (!record)
+    return true;
 
   // Maintain semantics for classes with non-trivial dtors or copy ctors.
   if (!record->hasTrivialDestructor()) return false;
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index d9cc37d123fb4..9ef38e0a340c3 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -997,9 +997,8 @@ static const FieldDecl *FindFlexibleArrayMemberField(CodeGenFunction &CGF,
             /*IgnoreTemplateOrMacroSubstitution=*/true))
       return FD;
 
-    if (auto RT = FD->getType()->getAs<RecordType>())
-      if (const FieldDecl *FD =
-              FindFlexibleArrayMemberField(CGF, Ctx, RT->getAsRecordDecl()))
+    if (const auto *RD = FD->getType()->getAsRecordDecl())
+      if (const FieldDecl *FD = FindFlexibleArrayMemberField(CGF, Ctx, RD))
         return FD;
   }
 
@@ -1025,8 +1024,8 @@ static bool GetFieldOffset(ASTContext &Ctx, const RecordDecl *RD,
       return true;
     }
 
-    if (auto RT = Field->getType()->getAs<RecordType>()) {
-      if (GetFieldOffset(Ctx, RT->getAsRecordDecl(), FD, Offset)) {
+    if (const auto *RD = Field->getType()->getAsRecordDecl()) {
+      if (GetFieldOffset(Ctx, RD, FD, Offset)) {
         Offset += Layout.getFieldOffset(FieldNo);
         return true;
       }
diff --git a/clang/lib/CodeGen/CGCUDANV.cpp b/clang/lib/CodeGen/CGCUDANV.cpp
index c7f4bf8a21354..5090a0559eab2 100644
--- a/clang/lib/CodeGen/CGCUDANV.cpp
+++ b/clang/lib/CodeGen/CGCUDANV.cpp
@@ -1131,8 +1131,7 @@ void CGNVCUDARuntime::handleVarRegistration(const VarDecl *D,
     // Builtin surfaces and textures and their template arguments are
     // also registered with CUDA runtime.
     const auto *TD = cast<ClassTemplateSpecializationDecl>(
-                         D->getType()->castAs<RecordType>()->getOriginalDecl())
-                         ->getDefinitionOrSelf();
+        D->getType()->castAsCXXRecordDecl());
     const TemplateArgumentList &Args = TD->getTemplateArgs();
     if (TD->hasAttr<CUDADeviceBuiltinSurfaceTypeAttr>()) {
       assert(Args.size() == 2 &&
diff --git a/clang/lib/CodeGen/CGCXX.cpp b/clang/lib/CodeGen/CGCXX.cpp
index f9aff893eb0f0..59aeff6804b61 100644
--- a/clang/lib/CodeGen/CGCXX.cpp
+++ b/clang/lib/CodeGen/CGCXX.cpp
@@ -83,9 +83,7 @@ bool CodeGenModule::TryEmitBaseDestructorAsAlias(const CXXDestructorDecl *D) {
     if (I.isVirtual()) continue;
 
     // Skip base classes with trivial destructors.
-    const auto *Base = cast<CXXRecordDecl>(
-                           I.getType()->castAs<RecordType>()->getOriginalDecl())
-                           ->getDefinitionOrSelf();
+    const auto *Base = I.getType()->castAsCXXRecordDecl();
     if (Base->hasTrivialDestructor()) continue;
 
     // If we've already found a base class with a non-trivial
@@ -281,16 +279,8 @@ static CGCallee BuildAppleKextVirtualCall(CodeGenFunction &CGF,
 CGCallee CodeGenFunction::BuildAppleKextVirtualCall(const CXXMethodDecl *MD,
                                                     NestedNameSpecifier Qual,
                                                     llvm::Type *Ty) {
-  assert(Qual.getKind() == NestedNameSpecifier::Kind::Type &&
-         "BuildAppleKextVirtualCall - bad Qual kind");
-
-  const Type *QTy = Qual.getAsType();
-  QualType T = QualType(QTy, 0);
-  const RecordType *RT = T->getAs<RecordType>();
-  assert(RT && "BuildAppleKextVirtualCall - Qual type must be record");
-  const auto *RD =
-      cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
-
+  const CXXRecordDecl *RD = Qual.getAsRecordDecl();
+  assert(RD && "BuildAppleKextVirtualCall - Qual must be record");
   if (const auto *DD = dyn_cast<CXXDestructorDecl>(MD))
     return BuildAppleKextVirtualDestructorCall(DD, Dtor_Complete, RD);
 
diff --git a/clang/lib/CodeGen/CGCXXABI.cpp b/clang/lib/CodeGen/CGCXXABI.cpp
index cca675838644e..30e5dc2b6cbd9 100644
--- a/clang/lib/CodeGen/CGCXXABI.cpp
+++ b/clang/lib/CodeGen/CGCXXABI.cpp
@@ -165,10 +165,7 @@ bool CGCXXABI::mayNeedDestruction(const VarDecl *VD) const {
   // If the variable has an incomplete class type (or array thereof), it
   // might need destruction.
   const Type *T = VD->getType()->getBaseElementTypeUnsafe();
-  if (T->getAs<RecordType>() && T->isIncompleteType())
-    return true;
-
-  return false;
+  return T->isRecordType() && T->isIncompleteType();
 }
 
 bool CGCXXABI::isEmittedWithConstantInitializer(
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 466546b2f1309..ce00e1f6e6b0f 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -1005,10 +1005,9 @@ getTypeExpansion(QualType Ty, const ASTContext &Context) {
     return std::make_unique<ConstantArrayExpansion>(AT->getElementType(),
                                                     AT->getZExtSize());
   }
-  if (const RecordType *RT = Ty->getAs<RecordType>()) {
+  if (const auto *RD = Ty->getAsRecordDecl()) {
     SmallVector<const CXXBaseSpecifier *, 1> Bases;
     SmallVector<const FieldDecl *, 1> Fields;
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
     assert(!RD->hasFlexibleArrayMember() &&
            "Cannot expand structure with flexible array.");
     if (RD->isUnion()) {
@@ -1897,7 +1896,7 @@ bool CodeGenModule::MayDropFunctionReturn(const ASTContext &Context,
   // We can't just discard the return value for a record type with a
   // complex destructor or a non-trivially copyable type.
   if (const RecordType *RT =
-          ReturnType.getCanonicalType()->getAs<RecordType>()) {
+          ReturnType.getCanonicalType()->getAsCanonical<RecordType>()) {
     if (const auto *ClassDecl = dyn_cast<CXXRecordDecl>(RT->getOriginalDecl()))
       return ClassDecl->hasTrivialDestructor();
   }
@@ -2874,10 +2873,7 @@ void CodeGenModule::ConstructAttributeList(StringRef Name,
           // whose destruction / clean-up is carried out within the callee
           // (e.g., Obj-C ARC-managed structs, MSVC callee-destroyed objects).
           if (!ParamType.isDestructedType() || !ParamType->isRecordType() ||
-              ParamType->castAs<RecordType>()
-                  ->getOriginalDecl()
-                  ->getDefinitionOrSelf()
-                  ->isParamDestroyedInCallee())
+              ParamType->castAsRecordDecl()->isParamDestroyedInCallee())
             Attrs.addAttribute(llvm::Attribute::DeadOnReturn);
         }
       }
@@ -3890,7 +3886,7 @@ static void setUsedBits(CodeGenModule &CGM, const ConstantArrayType *ATy,
 // the type `QTy`.
 static void setUsedBits(CodeGenModule &CGM, QualType QTy, int Offset,
                         SmallVectorImpl<uint64_t> &Bits) {
-  if (const auto *RTy = QTy->getAs<RecordType>())
+  if (const auto *RTy = QTy->getAsCanonical<RecordType>())
     return setUsedBits(CGM, RTy, Offset, Bits);
 
   ASTContext &Context = CGM.getContext();
@@ -3934,7 +3930,7 @@ llvm::Value *CodeGenFunction::EmitCMSEClearRecord(llvm::Value *Src,
   const llvm::DataLayout &DataLayout = CGM.getDataLayout();
   int Size = DataLayout.getTypeStoreSize(ITy);
   SmallVector<uint64_t, 4> Bits(Size);
-  setUsedBits(CGM, QTy->castAs<RecordType>(), 0, Bits);
+  setUsedBits(CGM, QTy->castAsCanonical<RecordType>(), 0, Bits);
 
   int CharWidth = CGM.getContext().getCharWidth();
   uint64_t Mask =
@@ -3951,7 +3947,7 @@ llvm::Value *CodeGenFunction::EmitCMSEClearRecord(llvm::Value *Src,
   const llvm::DataLayout &DataLayout = CGM.getDataLayout();
   int Size = DataLayout.getTypeStoreSize(ATy);
   SmallVector<uint64_t, 16> Bits(Size);
-  setUsedBits(CGM, QTy->castAs<RecordType>(), 0, Bits);
+  setUsedBits(CGM, QTy->castAsCanonical<RecordType>(), 0, Bits);
 
   // Clear each element of the LLVM array.
   int CharWidth = CGM.getContext().getCharWidth();
@@ -4308,10 +4304,7 @@ void CodeGenFunction::EmitDelegateCallArg(CallArgList &args,
 
   // Deactivate the cleanup for the callee-destructed param that was pushed.
   if (type->isRecordType() && !CurFuncIsThunk &&
-      type->castAs<RecordType>()
-          ->getOriginalDecl()
-          ->getDefinitionOrSelf()
-          ->isParamDestroyedInCallee() &&
+      type->castAsRecordDecl()->isParamDestroyedInCallee() &&
       param->needsDestruction(getContext())) {
     EHScopeStack::stable_iterator cleanup =
         CalleeDestructedParamCleanups.lookup(cast<ParmVarDecl>(param));
@@ -4904,10 +4897,8 @@ void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
   // In the Microsoft C++ ABI, aggregate arguments are destructed by the callee.
   // However, we still have to push an EH-only cleanup in case we unwind before
   // we make it to the call.
-  if (type->isRecordType() && type->castAs<RecordType>()
-                                  ->getOriginalDecl()
-                                  ->getDefinitionOrSelf()
-                                  ->isParamDestroyedInCallee()) {
+  if (type->isRecordType() &&
+      type->castAsRecordDecl()->isParamDestroyedInCallee()) {
     // If we're using inalloca, use the argument memory.  Otherwise, use a
     // temporary.
     AggValueSlot Slot = args.isUsingInAlloca()
diff --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp
index e9a92ae0f01cb..bae55aa1e1928 100644
--- a/clang/lib/CodeGen/CGClass.cpp
+++ b/clang/lib/CodeGen/CGClass.cpp
@@ -180,11 +180,7 @@ CharUnits CodeGenModule::computeNonVirtualBaseClassOffset(
     // Get the layout.
     const ASTRecordLayout &Layout = Context.getASTRecordLayout(RD);
 
-    const auto *BaseDecl =
-        cast<CXXRecordDecl>(
-            Base->getType()->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
-
+    const auto *BaseDecl = Base->getType()->castAsCXXRecordDecl();
     // Add the offset.
     Offset += Layout.getBaseClassOffset(BaseDecl);
 
@@ -302,9 +298,7 @@ Address CodeGenFunction::GetAddressOfBaseClass(
   // *start* with a step down to the correct virtual base subobject,
   // and hence will not require any further steps.
   if ((*Start)->isVirtual()) {
-    VBase = cast<CXXRecordDecl>(
-                (*Start)->getType()->castAs<RecordType>()->getOriginalDecl())
-                ->getDefinitionOrSelf();
+    VBase = (*Start)->getType()->castAsCXXRecordDecl();
     ++Start;
   }
 
@@ -559,10 +553,7 @@ static void EmitBaseInitializer(CodeGenFunction &CGF,
 
   Address ThisPtr = CGF.LoadCXXThisAddress();
 
-  const Type *BaseType = BaseInit->getBaseClass();
-  const auto *BaseClassDecl =
-      cast<CXXRecordDecl>(BaseType->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
+  const auto *BaseClassDecl = BaseInit->getBaseClass()->castAsCXXRecordDecl();
 
   bool isBaseVirtual = BaseInit->isBaseVirtual();
 
@@ -1267,10 +1258,7 @@ namespace {
 
 static bool isInitializerOfDynamicClass(const CXXCtorInitializer *BaseInit) {
   const Type *BaseType = BaseInit->getBaseClass();
-  const auto *BaseClassDecl =
-      cast<CXXRecordDecl>(BaseType->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
-  return BaseClassDecl->isDynamicClass();
+  return BaseType->castAsCXXRecordDecl()->isDynamicClass();
 }
 
 /// EmitCtorPrologue - This routine generates necessary code to initialize
@@ -1377,10 +1365,7 @@ HasTrivialDestructorBody(ASTContext &Context,
     if (I.isVirtual())
       continue;
 
-    const CXXRecordDecl *NonVirtualBase =
-        cast<CXXRecordDecl>(
-            I.getType()->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
+    const auto *NonVirtualBase = I.getType()->castAsCXXRecordDecl();
     if (!HasTrivialDestructorBody(Context, NonVirtualBase,
                                   MostDerivedClassDecl))
       return false;
@@ -1389,10 +1374,7 @@ HasTrivialDestructorBody(ASTContext &Context,
   if (BaseClassDecl == MostDerivedClassDecl) {
     // Check virtual bases.
     for (const auto &I : BaseClassDecl->vbases()) {
-      const auto *VirtualBase =
-          cast<CXXRecordDecl>(
-              I.getType()->castAs<RecordType>()->getOriginalDecl())
-              ->getDefinitionOrSelf();
+      const auto *VirtualBase = I.getType()->castAsCXXRecordDecl();
       if (!HasTrivialDestructorBody(Context, VirtualBase,
                                     MostDerivedClassDecl))
         return false;
@@ -1408,13 +1390,10 @@ FieldHasTrivialDestructorBody(ASTContext &Context,
 {
   QualType FieldBaseElementType = Context.getBaseElementType(Field->getType());
 
-  const RecordType *RT = FieldBaseElementType->getAs<RecordType>();
-  if (!RT)
+  auto *FieldClassDecl = FieldBaseElementType->getAsCXXRecordDecl();
+  if (!FieldClassDecl)
     return true;
 
-  auto *FieldClassDecl =
-      cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
-
   // The destructor for an implicit anonymous union member is never invoked.
   if (FieldClassDecl->isUnion() && FieldClassDecl->isAnonymousStructOrUnion())
     return true;
@@ -1907,11 +1886,7 @@ void CodeGenFunction::EnterDtorCleanups(const CXXDestructorDecl *DD,
     // We push them in the forward order so that they'll be popped in
     // the reverse order.
     for (const auto &Base : ClassDecl->vbases()) {
-      auto *BaseClassDecl =
-          cast<CXXRecordDecl>(
-              Base.getType()->castAs<RecordType>()->getOriginalDecl())
-              ->getDefinitionOrSelf();
-
+      auto *BaseClassDecl = Base.getType()->castAsCXXRecordDecl();
       if (BaseClassDecl->hasTrivialDestructor()) {
         // Under SanitizeMemoryUseAfterDtor, poison the trivial base class
         // memory. For non-trival base classes the same is done in the class
@@ -2130,10 +2105,7 @@ void CodeGenFunction::EmitCXXAggrConstructorCall(const CXXConstructorDecl *ctor,
 void CodeGenFunction::destroyCXXObject(CodeGenFunction &CGF,
                                        Address addr,
                                        QualType type) {
-  const RecordType *rtype = type->castAs<RecordType>();
-  const auto *record =
-      cast<CXXRecordDecl>(rtype->getOriginalDecl())->getDefinitionOrSelf();
-  const CXXDestructorDecl *dtor = record->getDestructor();
+  const CXXDestructorDecl *dtor = type->castAsCXXRecordDecl()->getDestructor();
   assert(!dtor->isTrivial());
   CGF.EmitCXXDestructorCall(dtor, Dtor_Complete, /*for vbase*/ false,
                             /*Delegating=*/false, addr, type);
@@ -2652,10 +2624,7 @@ void CodeGenFunction::getVTablePointers(BaseSubobject Base,
 
   // Traverse bases.
   for (const auto &I : RD->bases()) {
-    auto *BaseDecl = cast<CXXRecordDecl>(
-                         I.getType()->castAs<RecordType>()->getOriginalDecl())
-                         ->getDefinitionOrSelf();
-
+    auto *BaseDecl = I.getType()->castAsCXXRecordDecl();
     // Ignore classes without a vtable.
     if (!BaseDecl->isDynamicClass())
       continue;
@@ -2850,13 +2819,10 @@ void CodeGenFunction::EmitVTablePtrCheckForCast(QualType T, Address Derived,
   if (!getLangOpts().CPlusPlus)
     return;
 
-  auto *ClassTy = T->getAs<RecordType>();
-  if (!ClassTy)
+  const auto *ClassDecl = T->getAsCXXRecordDecl();
+  if (!ClassDecl)
     return;
 
-  const auto *ClassDecl =
-      cast<CXXRecordDecl>(ClassTy->getOriginalDecl())->getDefinitionOrSelf();
-
   if (!ClassDecl->isCompleteDefinition() || !ClassDecl->isDynamicClass())
     return;
 
diff --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp
index c44fea3e6383d..0385dbdac869b 100644
--- a/clang/lib/CodeGen/CGDebugInfo.cpp
+++ b/clang/lib/CodeGen/CGDebugInfo.cpp
@@ -2359,7 +2359,7 @@ void CGDebugInfo::CollectCXXBasesAux(
   for (const auto &BI : Bases) {
     const auto *Base =
         cast<CXXRecordDecl>(
-            BI.getType()->castAs<RecordType>()->getOriginalDecl())
+            BI.getType()->castAsCanonical<RecordType>()->getOriginalDecl())
             ->getDefinition();
     if (!SeenTypes.insert(Base).second)
       continue;
@@ -5921,8 +5921,7 @@ void CGDebugInfo::EmitGlobalVariable(llvm::GlobalVariable *Var,
   // variable for each member of the anonymous union so that it's possible
   // to find the name of any field in the union.
   if (T->isUnionType() && DeclName.empty()) {
-    const RecordDecl *RD =
-        T->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *RD = T->castAsRecordDecl();
     assert(RD->isAnonymousStructOrUnion() &&
            "unnamed non-anonymous struct or union?");
     GVE = CollectAnonRecordDecls(RD, Unit, LineNo, LinkageName, Var, DContext);
diff --git a/clang/lib/CodeGen/CGDecl.cpp b/clang/lib/CodeGen/CGDecl.cpp
index 8a1675848e13c..29193e0c541b9 100644
--- a/clang/lib/CodeGen/CGDecl.cpp
+++ b/clang/lib/CodeGen/CGDecl.cpp
@@ -1568,10 +1568,9 @@ CodeGenFunction::EmitAutoVarAlloca(const VarDecl &D) {
                      ReturnValue.getElementType(), ReturnValue.getAlignment());
       address = MaybeCastStackAddressSpace(AllocaAddr, Ty.getAddressSpace());
 
-      if (const RecordType *RecordTy = Ty->getAs<RecordType>()) {
-        const auto *RD = RecordTy->getOriginalDecl()->getDefinitionOrSelf();
-        const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD);
-        if ((CXXRD && !CXXRD->hasTrivialDestructor()) ||
+      if (const auto *RD = Ty->getAsRecordDecl()) {
+        if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD);
+            (CXXRD && !CXXRD->hasTrivialDestructor()) ||
             RD->isNonTrivialToPrimitiveDestroy()) {
           // Create a flag that is used to indicate when the NRVO was applied
           // to this variable. Set it to zero to indicate that NRVO was not
@@ -2728,10 +2727,7 @@ void CodeGenFunction::EmitParmDecl(const VarDecl &D, ParamValue Arg,
     // Don't push a cleanup in a thunk for a method that will also emit a
     // cleanup.
     if (Ty->isRecordType() && !CurFuncIsThunk &&
-        Ty->castAs<RecordType>()
-            ->getOriginalDecl()
-            ->getDefinitionOrSelf()
-            ->isParamDestroyedInCallee()) {
+        Ty->castAsRecordDecl()->isParamDestroyedInCallee()) {
       if (QualType::DestructionKind DtorKind =
               D.needsDestruction(getContext())) {
         assert((DtorKind == QualType::DK_cxx_destructor ||
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 2329fa20a2530..9e1b886263ed0 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -414,14 +414,11 @@ pushTemporaryCleanup(CodeGenFunction &CGF, const MaterializeTemporaryExpr *M,
     case SD_Static:
     case SD_Thread: {
       CXXDestructorDecl *ReferenceTemporaryDtor = nullptr;
-      if (const RecordType *RT =
-              E->getType()->getBaseElementTypeUnsafe()->getAs<RecordType>()) {
+      if (const auto *ClassDecl =
+              E->getType()->getBaseElementTypeUnsafe()->getAsCXXRecordDecl();
+          ClassDecl && !ClassDecl->hasTrivialDestructor())
         // Get the destructor for the reference temporary.
-        if (auto *ClassDecl = dyn_cast<CXXRecordDecl>(RT->getOriginalDecl());
-            ClassDecl && !ClassDecl->hasTrivialDestructor())
-          ReferenceTemporaryDtor =
-              ClassDecl->getDefinitionOrSelf()->getDestructor();
-      }
+        ReferenceTemporaryDtor = ClassDecl->getDestructor();
 
       if (!ReferenceTemporaryDtor)
         return;
@@ -1929,11 +1926,9 @@ llvm::Value *CodeGenFunction::EmitLoadOfScalar(LValue lvalue,
 static bool getRangeForType(CodeGenFunction &CGF, QualType Ty,
                             llvm::APInt &Min, llvm::APInt &End,
                             bool StrictEnums, bool IsBool) {
-  const EnumType *ET = Ty->getAs<EnumType>();
-  const EnumDecl *ED =
-      ET ? ET->getOriginalDecl()->getDefinitionOrSelf() : nullptr;
+  const auto *ED = Ty->getAsEnumDecl();
   bool IsRegularCPlusPlusEnum =
-      CGF.getLangOpts().CPlusPlus && StrictEnums && ET && !ED->isFixed();
+      CGF.getLangOpts().CPlusPlus && StrictEnums && ED && !ED->isFixed();
   if (!IsBool && !IsRegularCPlusPlusEnum)
     return false;
 
@@ -1980,7 +1975,7 @@ bool CodeGenFunction::EmitScalarRangeCheck(llvm::Value *Value, QualType Ty,
   bool IsBool = (Ty->hasBooleanRepresentation() && !Ty->isVectorType()) ||
                 NSAPI(CGM.getContext()).isObjCBOOLType(Ty);
   bool NeedsBoolCheck = HasBoolCheck && IsBool;
-  bool NeedsEnumCheck = HasEnumCheck && Ty->getAs<EnumType>();
+  bool NeedsEnumCheck = HasEnumCheck && Ty->isEnumeralType();
   if (!NeedsBoolCheck && !NeedsEnumCheck)
     return false;
 
@@ -5683,12 +5678,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
 
   case CK_UncheckedDerivedToBase:
   case CK_DerivedToBase: {
-    const auto *DerivedClassTy =
-        E->getSubExpr()->getType()->castAs<RecordType>();
-    auto *DerivedClassDecl =
-        cast<CXXRecordDecl>(DerivedClassTy->getOriginalDecl())
-            ->getDefinitionOrSelf();
-
+    auto *DerivedClassDecl = E->getSubExpr()->getType()->castAsCXXRecordDecl();
     LValue LV = EmitLValue(E->getSubExpr());
     Address This = LV.getAddress();
 
@@ -5706,11 +5696,7 @@ LValue CodeGenFunction::EmitCastLValue(const CastExpr *E) {
   case CK_ToUnion:
     return EmitAggExprToLValue(E);
   case CK_BaseToDerived: {
-    const auto *DerivedClassTy = E->getType()->castAs<RecordType>();
-    auto *DerivedClassDecl =
-        cast<CXXRecordDecl>(DerivedClassTy->getOriginalDecl())
-            ->getDefinitionOrSelf();
-
+    auto *DerivedClassDecl = E->getType()->castAsCXXRecordDecl();
     LValue LV = EmitLValue(E->getSubExpr());
 
     // Perform the base-to-derived conversion
diff --git a/clang/lib/CodeGen/CGExprAgg.cpp b/clang/lib/CodeGen/CGExprAgg.cpp
index 04e125c54b1ca..b8150a24d45fc 100644
--- a/clang/lib/CodeGen/CGExprAgg.cpp
+++ b/clang/lib/CodeGen/CGExprAgg.cpp
@@ -268,11 +268,11 @@ void AggExprEmitter::EmitAggLoadOfLValue(const Expr *E) {
 /// True if the given aggregate type requires special GC API calls.
 bool AggExprEmitter::TypeRequiresGCollection(QualType T) {
   // Only record types have members that might require garbage collection.
-  const RecordType *RecordTy = T->getAs<RecordType>();
-  if (!RecordTy) return false;
+  const auto *Record = T->getAsRecordDecl();
+  if (!Record)
+    return false;
 
   // Don't mess with non-trivial C++ types.
-  RecordDecl *Record = RecordTy->getOriginalDecl()->getDefinitionOrSelf();
   if (isa<CXXRecordDecl>(Record) &&
       (cast<CXXRecordDecl>(Record)->hasNonTrivialCopyConstructor() ||
        !cast<CXXRecordDecl>(Record)->hasTrivialDestructor()))
@@ -424,10 +424,7 @@ AggExprEmitter::VisitCXXStdInitializerListExpr(CXXStdInitializerListExpr *E) {
       Ctx.getAsConstantArrayType(E->getSubExpr()->getType());
   assert(ArrayType && "std::initializer_list constructed from non-array");
 
-  RecordDecl *Record = E->getType()
-                           ->castAs<RecordType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf();
+  auto *Record = E->getType()->castAsRecordDecl();
   RecordDecl::field_iterator Field = Record->field_begin();
   assert(Field != Record->field_end() &&
          Ctx.hasSameType(Field->getType()->getPointeeType(),
@@ -1809,10 +1806,7 @@ void AggExprEmitter::VisitCXXParenListOrInitListExpr(
   // the disadvantage is that the generated code is more difficult for
   // the optimizer, especially with bitfields.
   unsigned NumInitElements = InitExprs.size();
-  RecordDecl *record = ExprToVisit->getType()
-                           ->castAs<RecordType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf();
+  RecordDecl *record = ExprToVisit->getType()->castAsRecordDecl();
 
   // We'll need to enter cleanup scopes in case any of the element
   // initializers throws an exception.
@@ -2120,7 +2114,7 @@ static CharUnits GetNumNonZeroBytesInInit(const Expr *E, CodeGenFunction &CGF) {
   // InitListExprs for structs have to be handled carefully.  If there are
   // reference members, we need to consider the size of the reference, not the
   // referencee.  InitListExprs for unions and arrays can't have references.
-  if (const RecordType *RT = E->getType()->getAs<RecordType>()) {
+  if (const RecordType *RT = E->getType()->getAsCanonical<RecordType>()) {
     if (!RT->isUnionType()) {
       RecordDecl *SD = RT->getOriginalDecl()->getDefinitionOrSelf();
       CharUnits NumNonZeroBytes = CharUnits::Zero();
@@ -2173,7 +2167,8 @@ static void CheckAggExprForMemSetUse(AggValueSlot &Slot, const Expr *E,
   // C++ objects with a user-declared constructor don't need zero'ing.
   if (CGF.getLangOpts().CPlusPlus)
     if (const RecordType *RT = CGF.getContext()
-                       .getBaseElementType(E->getType())->getAs<RecordType>()) {
+                                   .getBaseElementType(E->getType())
+                                   ->getAsCanonical<RecordType>()) {
       const CXXRecordDecl *RD = cast<CXXRecordDecl>(RT->getOriginalDecl());
       if (RD->hasUserDeclaredConstructor())
         return;
@@ -2294,9 +2289,7 @@ void CodeGenFunction::EmitAggregateCopy(LValue Dest, LValue Src, QualType Ty,
   Address SrcPtr = Src.getAddress();
 
   if (getLangOpts().CPlusPlus) {
-    if (const RecordType *RT = Ty->getAs<RecordType>()) {
-      auto *Record =
-          cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
+    if (const auto *Record = Ty->getAsCXXRecordDecl()) {
       assert((Record->hasTrivialCopyConstructor() ||
               Record->hasTrivialCopyAssignment() ||
               Record->hasTrivialMoveConstructor() ||
@@ -2379,8 +2372,7 @@ void CodeGenFunction::EmitAggregateCopy(LValue Dest, LValue Src, QualType Ty,
   // Don't do any of the memmove_collectable tests if GC isn't set.
   if (CGM.getLangOpts().getGC() == LangOptions::NonGC) {
     // fall through
-  } else if (const RecordType *RecordTy = Ty->getAs<RecordType>()) {
-    RecordDecl *Record = RecordTy->getOriginalDecl()->getDefinitionOrSelf();
+  } else if (const auto *Record = Ty->getAsRecordDecl()) {
     if (Record->hasObjectMember()) {
       CGM.getObjCRuntime().EmitGCMemmoveCollectable(*this, DestPtr, SrcPtr,
                                                     SizeVal);
@@ -2388,10 +2380,8 @@ void CodeGenFunction::EmitAggregateCopy(LValue Dest, LValue Src, QualType Ty,
     }
   } else if (Ty->isArrayType()) {
     QualType BaseType = getContext().getBaseElementType(Ty);
-    if (const RecordType *RecordTy = BaseType->getAs<RecordType>()) {
-      if (RecordTy->getOriginalDecl()
-              ->getDefinitionOrSelf()
-              ->hasObjectMember()) {
+    if (const auto *Record = BaseType->getAsRecordDecl()) {
+      if (Record->hasObjectMember()) {
         CGM.getObjCRuntime().EmitGCMemmoveCollectable(*this, DestPtr, SrcPtr,
                                                       SizeVal);
         return;
diff --git a/clang/lib/CodeGen/CGExprCXX.cpp b/clang/lib/CodeGen/CGExprCXX.cpp
index 57d7eecc62ee1..1e4c72a210f9a 100644
--- a/clang/lib/CodeGen/CGExprCXX.cpp
+++ b/clang/lib/CodeGen/CGExprCXX.cpp
@@ -180,8 +180,7 @@ static CXXRecordDecl *getCXXRecord(const Expr *E) {
   QualType T = E->getType();
   if (const PointerType *PTy = T->getAs<PointerType>())
     T = PTy->getPointeeType();
-  const RecordType *Ty = T->castAs<RecordType>();
-  return cast<CXXRecordDecl>(Ty->getOriginalDecl())->getDefinitionOrSelf();
+  return T->castAsCXXRecordDecl();
 }
 
 // Note: This function also emit constructor calls to support a MSVC
@@ -1235,9 +1234,10 @@ void CodeGenFunction::EmitNewArrayInitializer(
   // If we have a struct whose every field is value-initialized, we can
   // usually use memset.
   if (auto *ILE = dyn_cast<InitListExpr>(Init)) {
-    if (const RecordType *RType = ILE->getType()->getAs<RecordType>()) {
-      const RecordDecl *RD = RType->getOriginalDecl()->getDefinitionOrSelf();
-      if (RD->isStruct()) {
+    if (const RecordType *RType =
+            ILE->getType()->getAsCanonical<RecordType>()) {
+      if (RType->getOriginalDecl()->isStruct()) {
+        const RecordDecl *RD = RType->getOriginalDecl()->getDefinitionOrSelf();
         unsigned NumElements = 0;
         if (auto *CXXRD = dyn_cast<CXXRecordDecl>(RD))
           NumElements = CXXRD->getNumBases();
@@ -1687,11 +1687,8 @@ llvm::Value *CodeGenFunction::EmitCXXNewExpr(const CXXNewExpr *E) {
       QualType AlignValT = sizeType;
       if (allocatorType->getNumParams() > IndexOfAlignArg) {
         AlignValT = allocatorType->getParamType(IndexOfAlignArg);
-        assert(getContext().hasSameUnqualifiedType(AlignValT->castAs<EnumType>()
-                                                       ->getOriginalDecl()
-                                                       ->getDefinitionOrSelf()
-                                                       ->getIntegerType(),
-                                                   sizeType) &&
+        assert(getContext().hasSameUnqualifiedType(
+                   AlignValT->castAsEnumDecl()->getIntegerType(), sizeType) &&
                "wrong type for alignment parameter");
         ++ParamsToSkip;
       } else {
@@ -1973,9 +1970,7 @@ static bool EmitObjectDelete(CodeGenFunction &CGF,
   // Find the destructor for the type, if applicable.  If the
   // destructor is virtual, we'll just emit the vcall and return.
   const CXXDestructorDecl *Dtor = nullptr;
-  if (const RecordType *RT = ElementType->getAs<RecordType>()) {
-    auto *RD =
-        cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
+  if (const auto *RD = ElementType->getAsCXXRecordDecl()) {
     if (RD->hasDefinition() && !RD->hasTrivialDestructor()) {
       Dtor = RD->getDestructor();
 
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index a96c1518d2a1d..36534025ebd0f 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -714,10 +714,7 @@ static bool EmitDesignatedInitUpdater(ConstantEmitter &Emitter,
 }
 
 bool ConstStructBuilder::Build(const InitListExpr *ILE, bool AllowOverwrite) {
-  RecordDecl *RD = ILE->getType()
-                       ->castAs<RecordType>()
-                       ->getOriginalDecl()
-                       ->getDefinitionOrSelf();
+  auto *RD = ILE->getType()->castAsRecordDecl();
   const ASTRecordLayout &Layout = CGM.getContext().getASTRecordLayout(RD);
 
   unsigned FieldNo = -1;
@@ -980,8 +977,7 @@ bool ConstStructBuilder::DoZeroInitPadding(const ASTRecordLayout &Layout,
 
 llvm::Constant *ConstStructBuilder::Finalize(QualType Type) {
   Type = Type.getNonReferenceType();
-  RecordDecl *RD =
-      Type->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+  auto *RD = Type->castAsRecordDecl();
   llvm::Type *ValTy = CGM.getTypes().ConvertType(Type);
   return Builder.build(ValTy, RD->hasFlexibleArrayMember());
 }
@@ -1004,8 +1000,7 @@ llvm::Constant *ConstStructBuilder::BuildStruct(ConstantEmitter &Emitter,
   ConstantAggregateBuilder Const(Emitter.CGM);
   ConstStructBuilder Builder(Emitter, Const, CharUnits::Zero());
 
-  const RecordDecl *RD =
-      ValTy->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *RD = ValTy->castAsRecordDecl();
   const CXXRecordDecl *CD = dyn_cast<CXXRecordDecl>(RD);
   if (!Builder.Build(Val, RD, false, CD, CharUnits::Zero()))
     return nullptr;
@@ -1510,10 +1505,8 @@ class ConstExprEmitter
 
     llvm::Type *ValTy = CGM.getTypes().ConvertType(destType);
     bool HasFlexibleArray = false;
-    if (const auto *RT = destType->getAs<RecordType>())
-      HasFlexibleArray = RT->getOriginalDecl()
-                             ->getDefinitionOrSelf()
-                             ->hasFlexibleArrayMember();
+    if (const auto *RD = destType->getAsRecordDecl())
+      HasFlexibleArray = RD->hasFlexibleArrayMember();
     return Const.build(ValTy, HasFlexibleArray);
   }
 
@@ -2646,11 +2639,7 @@ static llvm::Constant *EmitNullConstant(CodeGenModule &CGM,
         continue;
       }
 
-      const CXXRecordDecl *base =
-          cast<CXXRecordDecl>(
-              I.getType()->castAs<RecordType>()->getOriginalDecl())
-              ->getDefinitionOrSelf();
-
+      const auto *base = I.getType()->castAsCXXRecordDecl();
       // Ignore empty bases.
       if (isEmptyRecordForLayout(CGM.getContext(), I.getType()) ||
           CGM.getContext()
@@ -2688,11 +2677,7 @@ static llvm::Constant *EmitNullConstant(CodeGenModule &CGM,
   // Fill in the virtual bases, if we're working with the complete object.
   if (CXXR && asCompleteObject) {
     for (const auto &I : CXXR->vbases()) {
-      const auto *base =
-          cast<CXXRecordDecl>(
-              I.getType()->castAs<RecordType>()->getOriginalDecl())
-              ->getDefinitionOrSelf();
-
+      const auto *base = I.getType()->castAsCXXRecordDecl();
       // Ignore empty bases.
       if (isEmptyRecordForLayout(CGM.getContext(), I.getType()))
         continue;
@@ -2756,10 +2741,9 @@ llvm::Constant *CodeGenModule::EmitNullConstant(QualType T) {
     return llvm::ConstantArray::get(ATy, Array);
   }
 
-  if (const RecordType *RT = T->getAs<RecordType>())
-    return ::EmitNullConstant(*this,
-                              RT->getOriginalDecl()->getDefinitionOrSelf(),
-                              /*complete object*/ true);
+  if (const auto *RD = T->getAsRecordDecl())
+    return ::EmitNullConstant(*this, RD,
+                              /*asCompleteObject=*/true);
 
   assert(T->isMemberDataPointerType() &&
          "Should only see pointers to data members here!");
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index fcbfa5e31d149..63bc0d82a2b56 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -3515,9 +3515,7 @@ Value *ScalarExprEmitter::VisitOffsetOfExpr(OffsetOfExpr *E) {
 
     case OffsetOfNode::Field: {
       FieldDecl *MemberDecl = ON.getField();
-      RecordDecl *RD = CurrentType->castAs<RecordType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf();
+      auto *RD = CurrentType->castAsRecordDecl();
       const ASTRecordLayout &RL = CGF.getContext().getASTRecordLayout(RD);
 
       // Compute the index of the field in its parent.
@@ -3551,15 +3549,13 @@ Value *ScalarExprEmitter::VisitOffsetOfExpr(OffsetOfExpr *E) {
       }
 
       const ASTRecordLayout &RL = CGF.getContext().getASTRecordLayout(
-          CurrentType->castAs<RecordType>()->getOriginalDecl());
+          CurrentType->castAsCanonical<RecordType>()->getOriginalDecl());
 
       // Save the element type.
       CurrentType = ON.getBase()->getType();
 
       // Compute the offset to the base.
-      auto *BaseRT = CurrentType->castAs<RecordType>();
-      auto *BaseRD =
-          cast<CXXRecordDecl>(BaseRT->getOriginalDecl())->getDefinitionOrSelf();
+      auto *BaseRD = CurrentType->castAsCXXRecordDecl();
       CharUnits OffsetInt = RL.getBaseClassOffset(BaseRD);
       Offset = llvm::ConstantInt::get(ResultType, OffsetInt.getQuantity());
       break;
diff --git a/clang/lib/CodeGen/CGNonTrivialStruct.cpp b/clang/lib/CodeGen/CGNonTrivialStruct.cpp
index 1b941fff8b644..2d70e4c2e0394 100644
--- a/clang/lib/CodeGen/CGNonTrivialStruct.cpp
+++ b/clang/lib/CodeGen/CGNonTrivialStruct.cpp
@@ -39,8 +39,7 @@ template <class Derived> struct StructVisitor {
 
   template <class... Ts>
   void visitStructFields(QualType QT, CharUnits CurStructOffset, Ts... Args) {
-    const RecordDecl *RD =
-        QT->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *RD = QT->castAsRecordDecl();
 
     // Iterate over the fields of the struct.
     for (const FieldDecl *FD : RD->fields()) {
@@ -674,7 +673,8 @@ struct GenDefaultInitialize
     CharUnits Size = Ctx.getTypeSizeInChars(QualType(AT, 0));
     QualType EltTy = Ctx.getBaseElementType(QualType(AT, 0));
 
-    if (Size < CharUnits::fromQuantity(16) || EltTy->getAs<RecordType>()) {
+    if (Size < CharUnits::fromQuantity(16) ||
+        EltTy->getAsCanonical<RecordType>()) {
       GenFuncBaseTy::visitArray(FK, AT, IsVolatile, FD, CurStructOffset, Addrs);
       return;
     }
diff --git a/clang/lib/CodeGen/CGObjC.cpp b/clang/lib/CodeGen/CGObjC.cpp
index b5f17b812222a..b01d5471a9836 100644
--- a/clang/lib/CodeGen/CGObjC.cpp
+++ b/clang/lib/CodeGen/CGObjC.cpp
@@ -999,10 +999,8 @@ PropertyImplStrategy::PropertyImplStrategy(CodeGenModule &CGM,
 
   // Compute whether the ivar has strong members.
   if (CGM.getLangOpts().getGC())
-    if (const RecordType *recordType = ivarType->getAs<RecordType>())
-      HasStrong = recordType->getOriginalDecl()
-                      ->getDefinitionOrSelf()
-                      ->hasObjectMember();
+    if (const auto *RD = ivarType->getAsRecordDecl())
+      HasStrong = RD->hasObjectMember();
 
   // We can never access structs with object members with a native
   // access, because we need to use write barriers.  This is what
diff --git a/clang/lib/CodeGen/CGObjCMac.cpp b/clang/lib/CodeGen/CGObjCMac.cpp
index eb4904050ae0f..60f30a1334d6d 100644
--- a/clang/lib/CodeGen/CGObjCMac.cpp
+++ b/clang/lib/CodeGen/CGObjCMac.cpp
@@ -2315,7 +2315,7 @@ void IvarLayoutBuilder::visitBlock(const CGBlockInfo &blockInfo) {
     }
 
     assert(!type->isArrayType() && "array variable should not be caught");
-    if (const RecordType *record = type->getAs<RecordType>()) {
+    if (const RecordType *record = type->getAsCanonical<RecordType>()) {
       visitRecord(record, fieldOffset);
       continue;
     }
@@ -2409,7 +2409,7 @@ void CGObjCCommonMac::BuildRCRecordLayout(const llvm::StructLayout *RecLayout,
       if (FQT->isUnionType())
         HasUnion = true;
 
-      BuildRCBlockVarRecordLayout(FQT->castAs<RecordType>(),
+      BuildRCBlockVarRecordLayout(FQT->castAsCanonical<RecordType>(),
                                   BytePos + FieldOffset, HasUnion);
       continue;
     }
@@ -2426,7 +2426,7 @@ void CGObjCCommonMac::BuildRCRecordLayout(const llvm::StructLayout *RecLayout,
       }
       if (FQT->isRecordType() && ElCount) {
         int OldIndex = RunSkipBlockVars.size() - 1;
-        auto *RT = FQT->castAs<RecordType>();
+        auto *RT = FQT->castAsCanonical<RecordType>();
         BuildRCBlockVarRecordLayout(RT, BytePos + FieldOffset, HasUnion);
 
         // Replicate layout information for each array element. Note that
@@ -2831,7 +2831,7 @@ void CGObjCCommonMac::fillRunSkipBlockVars(CodeGenModule &CGM,
 
     assert(!type->isArrayType() && "array variable should not be caught");
     if (!CI.isByRef())
-      if (const RecordType *record = type->getAs<RecordType>()) {
+      if (const auto *record = type->getAsCanonical<RecordType>()) {
         BuildRCBlockVarRecordLayout(record, fieldOffset, hasUnion);
         continue;
       }
@@ -2865,7 +2865,7 @@ llvm::Constant *CGObjCCommonMac::BuildByrefLayout(CodeGen::CodeGenModule &CGM,
   CharUnits fieldOffset;
   RunSkipBlockVars.clear();
   bool hasUnion = false;
-  if (const RecordType *record = T->getAs<RecordType>()) {
+  if (const auto *record = T->getAsCanonical<RecordType>()) {
     BuildRCBlockVarRecordLayout(record, fieldOffset, hasUnion,
                                 true /*ByrefLayout */);
     llvm::Constant *Result = getBitmapBlockLayout(true);
@@ -3353,9 +3353,8 @@ static bool hasWeakMember(QualType type) {
     return true;
   }
 
-  if (auto recType = type->getAs<RecordType>()) {
-    for (auto *field :
-         recType->getOriginalDecl()->getDefinitionOrSelf()->fields()) {
+  if (auto *RD = type->getAsRecordDecl()) {
+    for (auto *field : RD->fields()) {
       if (hasWeakMember(field->getType()))
         return true;
     }
@@ -5247,7 +5246,7 @@ void IvarLayoutBuilder::visitField(const FieldDecl *field,
     return;
 
   // Recurse if the base element type is a record type.
-  if (auto recType = fieldType->getAs<RecordType>()) {
+  if (const auto *recType = fieldType->getAsCanonical<RecordType>()) {
     size_t oldEnd = IvarsInfo.size();
 
     visitRecord(recType, fieldOffset);
diff --git a/clang/lib/CodeGen/CGObjCRuntime.cpp b/clang/lib/CodeGen/CGObjCRuntime.cpp
index cbf99534d2ce6..76e0054f4c9da 100644
--- a/clang/lib/CodeGen/CGObjCRuntime.cpp
+++ b/clang/lib/CodeGen/CGObjCRuntime.cpp
@@ -439,10 +439,8 @@ void CGObjCRuntime::destroyCalleeDestroyedArguments(CodeGenFunction &CGF,
       CGF.EmitARCRelease(RV.getScalarVal(), ARCImpreciseLifetime);
     } else {
       QualType QT = param->getType();
-      auto *RT = QT->getAs<RecordType>();
-      if (RT && RT->getOriginalDecl()
-                    ->getDefinitionOrSelf()
-                    ->isParamDestroyedInCallee()) {
+      auto *RD = QT->getAsRecordDecl();
+      if (RD && RD->isParamDestroyedInCallee()) {
         RValue RV = I->getRValue(CGF);
         QualType::DestructionKind DtorKind = QT.isDestructedType();
         switch (DtorKind) {
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index f98339d472fa9..18ce8515bbee8 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -3006,10 +3006,10 @@ emitProxyTaskFunction(CodeGenModule &CGM, SourceLocation Loc,
       CGF.GetAddrOfLocalVar(&TaskTypeArg),
       KmpTaskTWithPrivatesPtrQTy->castAs<PointerType>());
   const auto *KmpTaskTWithPrivatesQTyRD =
-      cast<RecordDecl>(KmpTaskTWithPrivatesQTy->getAsTagDecl());
+      KmpTaskTWithPrivatesQTy->castAsRecordDecl();
   LValue Base =
       CGF.EmitLValueForField(TDBase, *KmpTaskTWithPrivatesQTyRD->field_begin());
-  const auto *KmpTaskTQTyRD = cast<RecordDecl>(KmpTaskTQTy->getAsTagDecl());
+  const auto *KmpTaskTQTyRD = KmpTaskTQTy->castAsRecordDecl();
   auto PartIdFI = std::next(KmpTaskTQTyRD->field_begin(), KmpTaskTPartId);
   LValue PartIdLVal = CGF.EmitLValueForField(Base, *PartIdFI);
   llvm::Value *PartidParam = PartIdLVal.getPointer(CGF);
@@ -3104,11 +3104,10 @@ static llvm::Value *emitDestructorsFunction(CodeGenModule &CGM,
       CGF.GetAddrOfLocalVar(&TaskTypeArg),
       KmpTaskTWithPrivatesPtrQTy->castAs<PointerType>());
   const auto *KmpTaskTWithPrivatesQTyRD =
-      cast<RecordDecl>(KmpTaskTWithPrivatesQTy->getAsTagDecl());
+      KmpTaskTWithPrivatesQTy->castAsRecordDecl();
   auto FI = std::next(KmpTaskTWithPrivatesQTyRD->field_begin());
   Base = CGF.EmitLValueForField(Base, *FI);
-  for (const auto *Field :
-       cast<RecordDecl>(FI->getType()->getAsTagDecl())->fields()) {
+  for (const auto *Field : FI->getType()->castAsRecordDecl()->fields()) {
     if (QualType::DestructionKind DtorKind =
             Field->getType().isDestructedType()) {
       LValue FieldLValue = CGF.EmitLValueForField(Base, Field);
@@ -3212,7 +3211,7 @@ emitTaskPrivateMappingFunction(CodeGenModule &CGM, SourceLocation Loc,
   LValue Base = CGF.EmitLoadOfPointerLValue(
       CGF.GetAddrOfLocalVar(&TaskPrivatesArg),
       TaskPrivatesArg.getType()->castAs<PointerType>());
-  const auto *PrivatesQTyRD = cast<RecordDecl>(PrivatesQTy->getAsTagDecl());
+  const auto *PrivatesQTyRD = PrivatesQTy->castAsRecordDecl();
   Counter = 0;
   for (const FieldDecl *Field : PrivatesQTyRD->fields()) {
     LValue FieldLVal = CGF.EmitLValueForField(Base, Field);
@@ -3259,7 +3258,7 @@ static void emitPrivatesInit(CodeGenFunction &CGF,
             CGF.ConvertTypeForMem(SharedsTy)),
         SharedsTy);
   }
-  FI = cast<RecordDecl>(FI->getType()->getAsTagDecl())->field_begin();
+  FI = FI->getType()->castAsRecordDecl()->field_begin();
   for (const PrivateDataTy &Pair : Privates) {
     // Do not initialize private locals.
     if (Pair.second.isLocalPrivate()) {
@@ -3655,7 +3654,7 @@ CGOpenMPRuntime::emitTaskInit(CodeGenFunction &CGF, SourceLocation Loc,
     }
     KmpTaskTQTy = SavedKmpTaskTQTy;
   }
-  const auto *KmpTaskTQTyRD = cast<RecordDecl>(KmpTaskTQTy->getAsTagDecl());
+  const auto *KmpTaskTQTyRD = KmpTaskTQTy->castAsRecordDecl();
   // Build particular struct kmp_task_t for the given task.
   const RecordDecl *KmpTaskTWithPrivatesQTyRD =
       createKmpTaskTWithPrivatesRecordDecl(CGM, KmpTaskTQTy, Privates);
@@ -4032,8 +4031,7 @@ CGOpenMPRuntime::getDepobjElements(CodeGenFunction &CGF, LValue DepobjLVal,
   ASTContext &C = CGM.getContext();
   QualType FlagsTy;
   getDependTypes(C, KmpDependInfoTy, FlagsTy);
-  RecordDecl *KmpDependInfoRD =
-      cast<RecordDecl>(KmpDependInfoTy->getAsTagDecl());
+  auto *KmpDependInfoRD = KmpDependInfoTy->castAsRecordDecl();
   QualType KmpDependInfoPtrTy = C.getPointerType(KmpDependInfoTy);
   LValue Base = CGF.EmitLoadOfPointerLValue(
       DepobjLVal.getAddress().withElementType(
@@ -4061,8 +4059,7 @@ static void emitDependData(CodeGenFunction &CGF, QualType &KmpDependInfoTy,
   ASTContext &C = CGM.getContext();
   QualType FlagsTy;
   getDependTypes(C, KmpDependInfoTy, FlagsTy);
-  RecordDecl *KmpDependInfoRD =
-      cast<RecordDecl>(KmpDependInfoTy->getAsTagDecl());
+  auto *KmpDependInfoRD = KmpDependInfoTy->castAsRecordDecl();
   llvm::Type *LLVMFlagsTy = CGF.ConvertTypeForMem(FlagsTy);
 
   OMPIteratorGeneratorScope IteratorScope(
@@ -4333,8 +4330,7 @@ Address CGOpenMPRuntime::emitDepobjDependClause(
   unsigned NumDependencies = Dependencies.DepExprs.size();
   QualType FlagsTy;
   getDependTypes(C, KmpDependInfoTy, FlagsTy);
-  RecordDecl *KmpDependInfoRD =
-      cast<RecordDecl>(KmpDependInfoTy->getAsTagDecl());
+  auto *KmpDependInfoRD = KmpDependInfoTy->castAsRecordDecl();
 
   llvm::Value *Size;
   // Define type kmp_depend_info[<Dependencies.size()>];
@@ -4442,8 +4438,7 @@ void CGOpenMPRuntime::emitUpdateClause(CodeGenFunction &CGF, LValue DepobjLVal,
   ASTContext &C = CGM.getContext();
   QualType FlagsTy;
   getDependTypes(C, KmpDependInfoTy, FlagsTy);
-  RecordDecl *KmpDependInfoRD =
-      cast<RecordDecl>(KmpDependInfoTy->getAsTagDecl());
+  auto *KmpDependInfoRD = KmpDependInfoTy->castAsRecordDecl();
   llvm::Type *LLVMFlagsTy = CGF.ConvertTypeForMem(FlagsTy);
   llvm::Value *NumDeps;
   LValue Base;
@@ -11319,7 +11314,7 @@ void CGOpenMPRuntime::emitDoacrossInit(CodeGenFunction &CGF,
     RD->completeDefinition();
     KmpDimTy = C.getCanonicalTagType(RD);
   } else {
-    RD = cast<RecordDecl>(KmpDimTy->getAsTagDecl());
+    RD = KmpDimTy->castAsRecordDecl();
   }
   llvm::APInt Size(/*numBits=*/32, NumIterations.size());
   QualType ArrayTy = C.getConstantArrayType(KmpDimTy, Size, nullptr,
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 652fe672f15e3..23e179881fccb 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -2221,14 +2221,9 @@ static void emitNonZeroVLAInit(CodeGenFunction &CGF, QualType baseType,
 void
 CodeGenFunction::EmitNullInitialization(Address DestPtr, QualType Ty) {
   // Ignore empty classes in C++.
-  if (getLangOpts().CPlusPlus) {
-    if (const RecordType *RT = Ty->getAs<RecordType>()) {
-      if (cast<CXXRecordDecl>(RT->getOriginalDecl())
-              ->getDefinitionOrSelf()
-              ->isEmpty())
-        return;
-    }
-  }
+  if (getLangOpts().CPlusPlus)
+    if (const auto *RD = Ty->getAsCXXRecordDecl(); RD && RD->isEmpty())
+      return;
 
   if (DestPtr.getElementType() != Int8Ty)
     DestPtr = DestPtr.withElementType(Int8Ty);
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index fc65199a0f154..a562a6a1ea6e1 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -2979,10 +2979,8 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// hasVolatileMember - returns true if aggregate type has a volatile
   /// member.
   bool hasVolatileMember(QualType T) {
-    if (const RecordType *RT = T->getAs<RecordType>()) {
-      const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
+    if (const auto *RD = T->getAsRecordDecl())
       return RD->hasVolatileMember();
-    }
     return false;
   }
 
diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp
index 677d8bc82cb0a..094221fed1aed 100644
--- a/clang/lib/CodeGen/CodeGenModule.cpp
+++ b/clang/lib/CodeGen/CodeGenModule.cpp
@@ -4195,7 +4195,8 @@ void CodeGenModule::EmitGlobal(GlobalDecl GD) {
 
 // Check if T is a class type with a destructor that's not dllimport.
 static bool HasNonDllImportDtor(QualType T) {
-  if (const auto *RT = T->getBaseElementTypeUnsafe()->getAs<RecordType>())
+  if (const auto *RT =
+          T->getBaseElementTypeUnsafe()->getAsCanonical<RecordType>())
     if (auto *RD = dyn_cast<CXXRecordDecl>(RT->getOriginalDecl())) {
       RD = RD->getDefinitionOrSelf();
       if (RD->getDestructor() && !RD->getDestructor()->hasAttr<DLLImportAttr>())
@@ -6075,8 +6076,7 @@ static bool isVarDeclStrongDefinition(const ASTContext &Context,
     if (Context.isAlignmentRequired(VarType))
       return true;
 
-    if (const auto *RT = VarType->getAs<RecordType>()) {
-      const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
+    if (const auto *RD = VarType->getAsRecordDecl()) {
       for (const FieldDecl *FD : RD->fields()) {
         if (FD->isBitField())
           continue;
diff --git a/clang/lib/CodeGen/CodeGenTBAA.cpp b/clang/lib/CodeGen/CodeGenTBAA.cpp
index bd2442f01cc50..f8c7d64cc1aa2 100644
--- a/clang/lib/CodeGen/CodeGenTBAA.cpp
+++ b/clang/lib/CodeGen/CodeGenTBAA.cpp
@@ -142,10 +142,9 @@ static bool TypeHasMayAlias(QualType QTy) {
 
 /// Check if the given type is a valid base type to be used in access tags.
 static bool isValidBaseType(QualType QTy) {
-  if (const RecordType *TTy = QTy->getAs<RecordType>()) {
-    const RecordDecl *RD = TTy->getOriginalDecl()->getDefinition();
+  if (const auto *RD = QTy->getAsRecordDecl()) {
     // Incomplete types are not valid base access types.
-    if (!RD)
+    if (!RD->isCompleteDefinition())
       return false;
     if (RD->hasFlexibleArrayMember())
       return false;
@@ -296,7 +295,7 @@ llvm::MDNode *CodeGenTBAA::getTypeInfoHelper(const Type *Ty) {
       // Be conservative if the type isn't a RecordType. We are specifically
       // required to do this for member pointers until we implement the
       // similar-types rule.
-      const auto *RT = Ty->getAs<RecordType>();
+      const auto *RT = Ty->getAsCanonical<RecordType>();
       if (!RT)
         return getAnyPtr(PtrDepth);
 
@@ -425,7 +424,7 @@ CodeGenTBAA::CollectFields(uint64_t BaseOffset,
                            bool MayAlias) {
   /* Things not handled yet include: C++ base classes, bitfields, */
 
-  if (const RecordType *TTy = QTy->getAs<RecordType>()) {
+  if (const auto *TTy = QTy->getAsCanonical<RecordType>()) {
     if (TTy->isUnionType()) {
       uint64_t Size = Context.getTypeSizeInChars(QTy).getQuantity();
       llvm::MDNode *TBAAType = getChar();
diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp
index f2a0a649a88fd..3ffe999d01178 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -311,12 +311,12 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
 
     // Force conversion of all the relevant record types, to make sure
     // we re-convert the FunctionType when appropriate.
-    if (const RecordType *RT = FT->getReturnType()->getAs<RecordType>())
-      ConvertRecordDeclType(RT->getOriginalDecl()->getDefinitionOrSelf());
+    if (const auto *RD = FT->getReturnType()->getAsRecordDecl())
+      ConvertRecordDeclType(RD);
     if (const FunctionProtoType *FPT = dyn_cast<FunctionProtoType>(FT))
       for (unsigned i = 0, e = FPT->getNumParams(); i != e; i++)
-        if (const RecordType *RT = FPT->getParamType(i)->getAs<RecordType>())
-          ConvertRecordDeclType(RT->getOriginalDecl()->getDefinitionOrSelf());
+        if (const auto *RD = FPT->getParamType(i)->getAsRecordDecl())
+          ConvertRecordDeclType(RD);
 
     SkippedLayout = true;
 
@@ -700,8 +700,7 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
     break;
 
   case Type::Enum: {
-    const EnumDecl *ED =
-        cast<EnumType>(Ty)->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *ED = Ty->castAsEnumDecl();
     if (ED->isCompleteDefinition() || ED->isFixed())
       return ConvertType(ED->getIntegerType());
     // Return a placeholder 'i32' type.  This can be changed later when the
@@ -814,10 +813,7 @@ llvm::StructType *CodeGenTypes::ConvertRecordDeclType(const RecordDecl *RD) {
   if (const CXXRecordDecl *CRD = dyn_cast<CXXRecordDecl>(RD)) {
     for (const auto &I : CRD->bases()) {
       if (I.isVirtual()) continue;
-      ConvertRecordDeclType(I.getType()
-                                ->castAs<RecordType>()
-                                ->getOriginalDecl()
-                                ->getDefinitionOrSelf());
+      ConvertRecordDeclType(I.getType()->castAsRecordDecl());
     }
   }
 
@@ -875,10 +871,8 @@ bool CodeGenTypes::isZeroInitializable(QualType T) {
 
   // Records are non-zero-initializable if they contain any
   // non-zero-initializable subobjects.
-  if (const RecordType *RT = T->getAs<RecordType>()) {
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *RD = T->getAsRecordDecl())
     return isZeroInitializable(RD);
-  }
 
   // We have to ask the ABI about member pointers.
   if (const MemberPointerType *MPT = T->getAs<MemberPointerType>())
diff --git a/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp b/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp
index ac56dda74abb7..a21feaa1120c6 100644
--- a/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp
+++ b/clang/lib/CodeGen/HLSLBufferLayoutBuilder.cpp
@@ -85,24 +85,22 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
   Layout.push_back(0);
 
   // iterate over all fields of the record, including fields on base classes
-  llvm::SmallVector<const RecordType *> RecordTypes;
-  RecordTypes.push_back(RT);
-  while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
-    CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
+  llvm::SmallVector<CXXRecordDecl *> RecordDecls;
+  RecordDecls.push_back(RT->castAsCXXRecordDecl());
+  while (RecordDecls.back()->getNumBases()) {
+    CXXRecordDecl *D = RecordDecls.back();
     assert(D->getNumBases() == 1 &&
            "HLSL doesn't support multiple inheritance");
-    RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>());
+    RecordDecls.push_back(D->bases_begin()->getType()->castAsCXXRecordDecl());
   }
 
   unsigned FieldOffset;
   llvm::Type *FieldType;
 
-  while (!RecordTypes.empty()) {
-    const RecordType *RT = RecordTypes.back();
-    RecordTypes.pop_back();
+  while (!RecordDecls.empty()) {
+    const CXXRecordDecl *RD = RecordDecls.pop_back_val();
 
-    for (const auto *FD :
-         RT->getOriginalDecl()->getDefinitionOrSelf()->fields()) {
+    for (const auto *FD : RD->fields()) {
       assert((!PackOffsets || Index < PackOffsets->size()) &&
              "number of elements in layout struct does not match number of "
              "packoffset annotations");
@@ -148,7 +146,7 @@ llvm::TargetExtType *HLSLBufferLayoutBuilder::createLayoutType(
 
   // create the layout struct type; anonymous struct have empty name but
   // non-empty qualified name
-  const CXXRecordDecl *Decl = RT->getAsCXXRecordDecl();
+  const auto *Decl = RT->castAsCXXRecordDecl();
   std::string Name =
       Decl->getName().empty() ? "anon" : Decl->getQualifiedNameAsString();
   llvm::StructType *StructTy =
@@ -203,8 +201,8 @@ bool HLSLBufferLayoutBuilder::layoutField(const FieldDecl *FD,
     // For array of structures, create a new array with a layout type
     // instead of the structure type.
     if (Ty->isStructureOrClassType()) {
-      llvm::Type *NewTy =
-          cast<llvm::TargetExtType>(createLayoutType(Ty->getAs<RecordType>()));
+      llvm::Type *NewTy = cast<llvm::TargetExtType>(
+          createLayoutType(Ty->getAsCanonical<RecordType>()));
       if (!NewTy)
         return false;
       assert(isa<llvm::TargetExtType>(NewTy) && "expected target type");
@@ -222,8 +220,8 @@ bool HLSLBufferLayoutBuilder::layoutField(const FieldDecl *FD,
 
   } else if (FieldTy->isStructureOrClassType()) {
     // Create a layout type for the structure
-    ElemLayoutTy =
-        createLayoutType(cast<RecordType>(FieldTy->getAs<RecordType>()));
+    ElemLayoutTy = createLayoutType(
+        cast<RecordType>(FieldTy->getAsCanonical<RecordType>()));
     if (!ElemLayoutTy)
       return false;
     assert(isa<llvm::TargetExtType>(ElemLayoutTy) && "expected target type");
diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp
index 4ed3775f156c9..885b700ffa193 100644
--- a/clang/lib/CodeGen/ItaniumCXXABI.cpp
+++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp
@@ -1397,9 +1397,7 @@ void ItaniumCXXABI::emitVirtualObjectDelete(CodeGenFunction &CGF,
     // to pass to the deallocation function.
 
     // Grab the vtable pointer as an intptr_t*.
-    auto *ClassDecl = cast<CXXRecordDecl>(
-                          ElementType->castAs<RecordType>()->getOriginalDecl())
-                          ->getDefinitionOrSelf();
+    auto *ClassDecl = ElementType->castAsCXXRecordDecl();
     llvm::Value *VTable = CGF.GetVTablePtr(Ptr, CGF.UnqualPtrTy, ClassDecl);
 
     // Track back to entry -2 and pull out the offset there.
@@ -1484,21 +1482,18 @@ void ItaniumCXXABI::emitThrow(CodeGenFunction &CGF, const CXXThrowExpr *E) {
   // The address of the destructor.  If the exception type has a
   // trivial destructor (or isn't a record), we just pass null.
   llvm::Constant *Dtor = nullptr;
-  if (const RecordType *RecordTy = ThrowType->getAs<RecordType>()) {
-    CXXRecordDecl *Record =
-        cast<CXXRecordDecl>(RecordTy->getOriginalDecl())->getDefinitionOrSelf();
-    if (!Record->hasTrivialDestructor()) {
-      // __cxa_throw is declared to take its destructor as void (*)(void *). We
-      // must match that if function pointers can be authenticated with a
-      // discriminator based on their type.
-      const ASTContext &Ctx = getContext();
-      QualType DtorTy = Ctx.getFunctionType(Ctx.VoidTy, {Ctx.VoidPtrTy},
-                                            FunctionProtoType::ExtProtoInfo());
-
-      CXXDestructorDecl *DtorD = Record->getDestructor();
-      Dtor = CGM.getAddrOfCXXStructor(GlobalDecl(DtorD, Dtor_Complete));
-      Dtor = CGM.getFunctionPointer(Dtor, DtorTy);
-    }
+  if (const auto *Record = ThrowType->getAsCXXRecordDecl();
+      Record && !Record->hasTrivialDestructor()) {
+    // __cxa_throw is declared to take its destructor as void (*)(void *). We
+    // must match that if function pointers can be authenticated with a
+    // discriminator based on their type.
+    const ASTContext &Ctx = getContext();
+    QualType DtorTy = Ctx.getFunctionType(Ctx.VoidTy, {Ctx.VoidPtrTy},
+                                          FunctionProtoType::ExtProtoInfo());
+
+    CXXDestructorDecl *DtorD = Record->getDestructor();
+    Dtor = CGM.getAddrOfCXXStructor(GlobalDecl(DtorD, Dtor_Complete));
+    Dtor = CGM.getFunctionPointer(Dtor, DtorTy);
   }
   if (!Dtor) Dtor = llvm::Constant::getNullValue(CGM.Int8PtrTy);
 
@@ -1612,9 +1607,7 @@ llvm::Value *ItaniumCXXABI::EmitTypeid(CodeGenFunction &CGF,
                                        QualType SrcRecordTy,
                                        Address ThisPtr,
                                        llvm::Type *StdTypeInfoPtrTy) {
-  auto *ClassDecl =
-      cast<CXXRecordDecl>(SrcRecordTy->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
+  auto *ClassDecl = SrcRecordTy->castAsCXXRecordDecl();
   llvm::Value *Value = CGF.GetVTablePtr(ThisPtr, CGM.GlobalsInt8PtrTy,
                                         ClassDecl);
 
@@ -1786,9 +1779,7 @@ llvm::Value *ItaniumCXXABI::emitExactDynamicCast(
 llvm::Value *ItaniumCXXABI::emitDynamicCastToVoid(CodeGenFunction &CGF,
                                                   Address ThisAddr,
                                                   QualType SrcRecordTy) {
-  auto *ClassDecl =
-      cast<CXXRecordDecl>(SrcRecordTy->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
+  auto *ClassDecl = SrcRecordTy->castAsCXXRecordDecl();
   llvm::Value *OffsetToTop;
   if (CGM.getItaniumVTableContext().isRelativeLayout()) {
     // Get the vtable pointer.
@@ -3872,9 +3863,7 @@ static bool CanUseSingleInheritance(const CXXRecordDecl *RD) {
     return false;
 
   // Check that the class is dynamic iff the base is.
-  auto *BaseDecl = cast<CXXRecordDecl>(
-                       Base->getType()->castAs<RecordType>()->getOriginalDecl())
-                       ->getDefinitionOrSelf();
+  auto *BaseDecl = Base->getType()->castAsCXXRecordDecl();
   if (!BaseDecl->isEmpty() &&
       BaseDecl->isDynamicClass() != RD->isDynamicClass())
     return false;
@@ -4394,10 +4383,7 @@ static unsigned ComputeVMIClassTypeInfoFlags(const CXXBaseSpecifier *Base,
 
   unsigned Flags = 0;
 
-  auto *BaseDecl = cast<CXXRecordDecl>(
-                       Base->getType()->castAs<RecordType>()->getOriginalDecl())
-                       ->getDefinitionOrSelf();
-
+  auto *BaseDecl = Base->getType()->castAsCXXRecordDecl();
   if (Base->isVirtual()) {
     // Mark the virtual base as seen.
     if (!Bases.VirtualBases.insert(BaseDecl).second) {
@@ -4495,11 +4481,7 @@ void ItaniumRTTIBuilder::BuildVMIClassTypeInfo(const CXXRecordDecl *RD) {
     // The __base_type member points to the RTTI for the base type.
     Fields.push_back(ItaniumRTTIBuilder(CXXABI).BuildTypeInfo(Base.getType()));
 
-    auto *BaseDecl =
-        cast<CXXRecordDecl>(
-            Base.getType()->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
-
+    auto *BaseDecl = Base.getType()->castAsCXXRecordDecl();
     int64_t OffsetFlags = 0;
 
     // All but the lower 8 bits of __offset_flags are a signed offset.
diff --git a/clang/lib/CodeGen/SwiftCallingConv.cpp b/clang/lib/CodeGen/SwiftCallingConv.cpp
index de58e0d95a866..4d894fd99db05 100644
--- a/clang/lib/CodeGen/SwiftCallingConv.cpp
+++ b/clang/lib/CodeGen/SwiftCallingConv.cpp
@@ -65,7 +65,7 @@ void SwiftAggLowering::addTypedData(QualType type, CharUnits begin) {
   // Deal with various aggregate types as special cases:
 
   // Record types.
-  if (auto recType = type->getAs<RecordType>()) {
+  if (auto recType = type->getAsCanonical<RecordType>()) {
     addTypedData(recType->getOriginalDecl(), begin);
 
     // Array types.
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp b/clang/lib/CodeGen/Targets/AArch64.cpp
index 289f8a9dcf211..d7deece232a9f 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -374,8 +374,8 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
 
   if (!passAsAggregateType(Ty)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     if (const auto *EIT = Ty->getAs<BitIntType>())
       if (EIT->getNumBits() > 128)
@@ -493,10 +493,9 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
     auto ContainsOnlyPointers = [&](const auto &Self, QualType Ty) {
       if (isEmptyRecord(getContext(), Ty, true))
         return false;
-      const RecordType *RT = Ty->getAs<RecordType>();
-      if (!RT)
+      const auto *RD = Ty->getAsRecordDecl();
+      if (!RD)
         return false;
-      const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
       if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
         for (const auto &I : CXXRD->bases())
           if (!Self(Self, I.getType()))
@@ -547,9 +546,8 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
 
   if (!passAsAggregateType(RetTy)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-      RetTy =
-          EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = RetTy->getAsEnumDecl())
+      RetTy = ED->getIntegerType();
 
     if (const auto *EIT = RetTy->getAs<BitIntType>())
       if (EIT->getNumBits() > 128)
@@ -738,7 +736,7 @@ bool AArch64ABIInfo::passAsPureScalableType(
     return true;
   }
 
-  if (const RecordType *RT = Ty->getAs<RecordType>()) {
+  if (const RecordType *RT = Ty->getAsCanonical<RecordType>()) {
     // If the record cannot be passed in registers, then it's not a PST.
     if (CGCXXABI::RecordArgABI RAA = getRecordArgABI(RT, getCXXABI());
         RAA != CGCXXABI::RAA_Default)
diff --git a/clang/lib/CodeGen/Targets/AMDGPU.cpp b/clang/lib/CodeGen/Targets/AMDGPU.cpp
index 41bccbb5721b2..0fcbf7e458a34 100644
--- a/clang/lib/CodeGen/Targets/AMDGPU.cpp
+++ b/clang/lib/CodeGen/Targets/AMDGPU.cpp
@@ -95,8 +95,7 @@ unsigned AMDGPUABIInfo::numRegsForType(QualType Ty) const {
     return EltNumRegs * VT->getNumElements();
   }
 
-  if (const RecordType *RT = Ty->getAs<RecordType>()) {
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *RD = Ty->getAsRecordDecl()) {
     assert(!RD->hasFlexibleArrayMember());
 
     for (const FieldDecl *Field : RD->fields()) {
@@ -152,11 +151,9 @@ ABIArgInfo AMDGPUABIInfo::classifyReturnType(QualType RetTy) const {
       if (const Type *SeltTy = isSingleElementStruct(RetTy, getContext()))
         return ABIArgInfo::getDirect(CGT.ConvertType(QualType(SeltTy, 0)));
 
-      if (const RecordType *RT = RetTy->getAs<RecordType>()) {
-        const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
-        if (RD->hasFlexibleArrayMember())
-          return DefaultABIInfo::classifyReturnType(RetTy);
-      }
+      if (const auto *RD = RetTy->getAsRecordDecl();
+          RD && RD->hasFlexibleArrayMember())
+        return DefaultABIInfo::classifyReturnType(RetTy);
 
       // Pack aggregates <= 4 bytes into single VGPR or pair.
       uint64_t Size = getContext().getTypeSize(RetTy);
@@ -245,11 +242,9 @@ ABIArgInfo AMDGPUABIInfo::classifyArgumentType(QualType Ty, bool Variadic,
     if (const Type *SeltTy = isSingleElementStruct(Ty, getContext()))
       return ABIArgInfo::getDirect(CGT.ConvertType(QualType(SeltTy, 0)));
 
-    if (const RecordType *RT = Ty->getAs<RecordType>()) {
-      const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
-      if (RD->hasFlexibleArrayMember())
-        return DefaultABIInfo::classifyArgumentType(Ty);
-    }
+    if (const auto *RD = Ty->getAsRecordDecl();
+        RD && RD->hasFlexibleArrayMember())
+      return DefaultABIInfo::classifyArgumentType(Ty);
 
     // Pack aggregates <= 8 bytes into single VGPR or pair.
     uint64_t Size = getContext().getTypeSize(Ty);
diff --git a/clang/lib/CodeGen/Targets/ARC.cpp b/clang/lib/CodeGen/Targets/ARC.cpp
index ace524e1976d9..67275877cbd9e 100644
--- a/clang/lib/CodeGen/Targets/ARC.cpp
+++ b/clang/lib/CodeGen/Targets/ARC.cpp
@@ -94,7 +94,7 @@ RValue ARCABIInfo::EmitVAArg(CodeGenFunction &CGF, Address VAListAddr,
 ABIArgInfo ARCABIInfo::classifyArgumentType(QualType Ty,
                                             uint8_t FreeRegs) const {
   // Handle the generic C++ ABI.
-  const RecordType *RT = Ty->getAs<RecordType>();
+  const RecordType *RT = Ty->getAsCanonical<RecordType>();
   if (RT) {
     CGCXXABI::RecordArgABI RAA = getRecordArgABI(RT, getCXXABI());
     if (RAA == CGCXXABI::RAA_Indirect)
@@ -105,8 +105,8 @@ ABIArgInfo ARCABIInfo::classifyArgumentType(QualType Ty,
   }
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   auto SizeInRegs = llvm::alignTo(getContext().getTypeSize(Ty), 32) / 32;
 
diff --git a/clang/lib/CodeGen/Targets/ARM.cpp b/clang/lib/CodeGen/Targets/ARM.cpp
index 3739e16788c3e..c84c9f2f643ee 100644
--- a/clang/lib/CodeGen/Targets/ARM.cpp
+++ b/clang/lib/CodeGen/Targets/ARM.cpp
@@ -382,8 +382,8 @@ ABIArgInfo ARMABIInfo::classifyArgumentType(QualType Ty, bool isVariadic,
 
   if (!isAggregateTypeForABI(Ty)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>()) {
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl()) {
+      Ty = ED->getIntegerType();
     }
 
     if (const auto *EIT = Ty->getAs<BitIntType>())
@@ -512,7 +512,7 @@ static bool isIntegerLikeType(QualType Ty, ASTContext &Context,
   // above, but they are not.
 
   // Otherwise, it must be a record type.
-  const RecordType *RT = Ty->getAs<RecordType>();
+  const RecordType *RT = Ty->getAsCanonical<RecordType>();
   if (!RT) return false;
 
   // Ignore records with flexible arrays.
@@ -592,9 +592,8 @@ ABIArgInfo ARMABIInfo::classifyReturnType(QualType RetTy, bool isVariadic,
 
   if (!isAggregateTypeForABI(RetTy)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-      RetTy =
-          EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = RetTy->getAsEnumDecl())
+      RetTy = ED->getIntegerType();
 
     if (const auto *EIT = RetTy->getAs<BitIntType>())
       if (EIT->getNumBits() > 64)
@@ -718,9 +717,8 @@ bool ARMABIInfo::containsAnyFP16Vectors(QualType Ty) const {
     if (NElements == 0)
       return false;
     return containsAnyFP16Vectors(AT->getElementType());
-  } else if (const RecordType *RT = Ty->getAs<RecordType>()) {
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
-
+  }
+  if (const auto *RD = Ty->getAsRecordDecl()) {
     // If this is a C++ record, check the bases first.
     if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD))
       if (llvm::any_of(CXXRD->bases(), [this](const CXXBaseSpecifier &B) {
diff --git a/clang/lib/CodeGen/Targets/BPF.cpp b/clang/lib/CodeGen/Targets/BPF.cpp
index 87d50e671d251..3a7af346f1132 100644
--- a/clang/lib/CodeGen/Targets/BPF.cpp
+++ b/clang/lib/CodeGen/Targets/BPF.cpp
@@ -47,8 +47,8 @@ class BPFABIInfo : public DefaultABIInfo {
       }
     }
 
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     ASTContext &Context = getContext();
     if (const auto *EIT = Ty->getAs<BitIntType>())
@@ -69,9 +69,8 @@ class BPFABIInfo : public DefaultABIInfo {
                                      getDataLayout().getAllocaAddrSpace());
 
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-      RetTy =
-          EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = RetTy->getAsEnumDecl())
+      RetTy = ED->getIntegerType();
 
     ASTContext &Context = getContext();
     if (const auto *EIT = RetTy->getAs<BitIntType>())
diff --git a/clang/lib/CodeGen/Targets/CSKY.cpp b/clang/lib/CodeGen/Targets/CSKY.cpp
index 7e5a16f30727f..df1da8d61a259 100644
--- a/clang/lib/CodeGen/Targets/CSKY.cpp
+++ b/clang/lib/CodeGen/Targets/CSKY.cpp
@@ -115,8 +115,8 @@ ABIArgInfo CSKYABIInfo::classifyArgumentType(QualType Ty, int &ArgGPRsLeft,
 
   if (!isAggregateTypeForABI(Ty)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     // All integral types are promoted to XLen width, unless passed on the
     // stack.
diff --git a/clang/lib/CodeGen/Targets/Hexagon.cpp b/clang/lib/CodeGen/Targets/Hexagon.cpp
index 0c423429eda4f..97a9300a1b8cd 100644
--- a/clang/lib/CodeGen/Targets/Hexagon.cpp
+++ b/clang/lib/CodeGen/Targets/Hexagon.cpp
@@ -97,8 +97,8 @@ ABIArgInfo HexagonABIInfo::classifyArgumentType(QualType Ty,
                                                 unsigned *RegsLeft) const {
   if (!isAggregateTypeForABI(Ty)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     uint64_t Size = getContext().getTypeSize(Ty);
     if (Size <= 64)
@@ -160,9 +160,8 @@ ABIArgInfo HexagonABIInfo::classifyReturnType(QualType RetTy) const {
 
   if (!isAggregateTypeForABI(RetTy)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-      RetTy =
-          EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = RetTy->getAsEnumDecl())
+      RetTy = ED->getIntegerType();
 
     if (Size > 64 && RetTy->isBitIntType())
       return getNaturalAlignIndirect(
diff --git a/clang/lib/CodeGen/Targets/Lanai.cpp b/clang/lib/CodeGen/Targets/Lanai.cpp
index 08cb36034f6fd..e76431a484e70 100644
--- a/clang/lib/CodeGen/Targets/Lanai.cpp
+++ b/clang/lib/CodeGen/Targets/Lanai.cpp
@@ -88,7 +88,7 @@ ABIArgInfo LanaiABIInfo::getIndirectResult(QualType Ty, bool ByVal,
 ABIArgInfo LanaiABIInfo::classifyArgumentType(QualType Ty,
                                               CCState &State) const {
   // Check with the C++ ABI first.
-  const RecordType *RT = Ty->getAs<RecordType>();
+  const RecordType *RT = Ty->getAsCanonical<RecordType>();
   if (RT) {
     CGCXXABI::RecordArgABI RAA = getRecordArgABI(RT, getCXXABI());
     if (RAA == CGCXXABI::RAA_Indirect) {
@@ -125,8 +125,8 @@ ABIArgInfo LanaiABIInfo::classifyArgumentType(QualType Ty,
   }
 
   // Treat an enum type as its underlying type.
-  if (const auto *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   bool InReg = shouldUseInReg(Ty, State);
 
diff --git a/clang/lib/CodeGen/Targets/LoongArch.cpp b/clang/lib/CodeGen/Targets/LoongArch.cpp
index af863e6101e2c..1f344d6582510 100644
--- a/clang/lib/CodeGen/Targets/LoongArch.cpp
+++ b/clang/lib/CodeGen/Targets/LoongArch.cpp
@@ -149,7 +149,7 @@ bool LoongArchABIInfo::detectFARsEligibleStructHelper(
     QualType EltTy = ATy->getElementType();
     // Non-zero-length arrays of empty records make the struct ineligible to be
     // passed via FARs in C++.
-    if (const auto *RTy = EltTy->getAs<RecordType>()) {
+    if (const auto *RTy = EltTy->getAsCanonical<RecordType>()) {
       if (ArraySize != 0 && isa<CXXRecordDecl>(RTy->getOriginalDecl()) &&
           isEmptyRecord(getContext(), EltTy, true, true))
         return false;
@@ -164,7 +164,7 @@ bool LoongArchABIInfo::detectFARsEligibleStructHelper(
     return true;
   }
 
-  if (const auto *RTy = Ty->getAs<RecordType>()) {
+  if (const auto *RTy = Ty->getAsCanonical<RecordType>()) {
     // Structures with either a non-trivial destructor or a non-trivial
     // copy constructor are not eligible for the FP calling convention.
     if (getRecordArgABI(Ty, CGT.getCXXABI()))
@@ -180,10 +180,7 @@ bool LoongArchABIInfo::detectFARsEligibleStructHelper(
     // If this is a C++ record, check the bases first.
     if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
       for (const CXXBaseSpecifier &B : CXXRD->bases()) {
-        const auto *BDecl =
-            cast<CXXRecordDecl>(
-                B.getType()->castAs<RecordType>()->getOriginalDecl())
-                ->getDefinitionOrSelf();
+        const auto *BDecl = B.getType()->castAsCXXRecordDecl();
         if (!detectFARsEligibleStructHelper(
                 B.getType(), CurOff + Layout.getBaseClassOffset(BDecl),
                 Field1Ty, Field1Off, Field2Ty, Field2Off))
@@ -370,8 +367,8 @@ ABIArgInfo LoongArchABIInfo::classifyArgumentType(QualType Ty, bool IsFixed,
 
   if (!isAggregateTypeForABI(Ty) && !Ty->isVectorType()) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     // All integral types are promoted to GRLen width.
     if (Size < GRLen && Ty->isIntegralOrEnumerationType())
diff --git a/clang/lib/CodeGen/Targets/Mips.cpp b/clang/lib/CodeGen/Targets/Mips.cpp
index e12a34ce07bbe..f26ab974d699e 100644
--- a/clang/lib/CodeGen/Targets/Mips.cpp
+++ b/clang/lib/CodeGen/Targets/Mips.cpp
@@ -153,7 +153,7 @@ llvm::Type* MipsABIInfo::HandleAggregates(QualType Ty, uint64_t TySize) const {
   if (Ty->isComplexType())
     return CGT.ConvertType(Ty);
 
-  const RecordType *RT = Ty->getAs<RecordType>();
+  const RecordType *RT = Ty->getAsCanonical<RecordType>();
 
   // Unions/vectors are passed in integer registers.
   if (!RT || !RT->isStructureOrClassType()) {
@@ -241,8 +241,8 @@ MipsABIInfo::classifyArgumentType(QualType Ty, uint64_t &Offset) const {
   }
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   // Make sure we pass indirectly things that are too large.
   if (const auto *EIT = Ty->getAs<BitIntType>())
@@ -261,7 +261,7 @@ MipsABIInfo::classifyArgumentType(QualType Ty, uint64_t &Offset) const {
 
 llvm::Type*
 MipsABIInfo::returnAggregateInRegs(QualType RetTy, uint64_t Size) const {
-  const RecordType *RT = RetTy->getAs<RecordType>();
+  const RecordType *RT = RetTy->getAsCanonical<RecordType>();
   SmallVector<llvm::Type*, 8> RTList;
 
   if (RT && RT->isStructureOrClassType()) {
@@ -332,8 +332,8 @@ ABIArgInfo MipsABIInfo::classifyReturnType(QualType RetTy) const {
   }
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-    RetTy = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = RetTy->getAsEnumDecl())
+    RetTy = ED->getIntegerType();
 
   // Make sure we pass indirectly things that are too large.
   if (const auto *EIT = RetTy->getAs<BitIntType>())
diff --git a/clang/lib/CodeGen/Targets/NVPTX.cpp b/clang/lib/CodeGen/Targets/NVPTX.cpp
index e874617796f86..400fa200bf7bc 100644
--- a/clang/lib/CodeGen/Targets/NVPTX.cpp
+++ b/clang/lib/CodeGen/Targets/NVPTX.cpp
@@ -130,10 +130,9 @@ bool NVPTXABIInfo::isUnsupportedType(QualType T) const {
     return true;
   if (const auto *AT = T->getAsArrayTypeUnsafe())
     return isUnsupportedType(AT->getElementType());
-  const auto *RT = T->getAs<RecordType>();
-  if (!RT)
+  const auto *RD = T->getAsRecordDecl();
+  if (!RD)
     return false;
-  const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
 
   // If this is a C++ record, check the bases first.
   if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD))
@@ -173,8 +172,8 @@ ABIArgInfo NVPTXABIInfo::classifyReturnType(QualType RetTy) const {
     return ABIArgInfo::getDirect();
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-    RetTy = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = RetTy->getAsEnumDecl())
+    RetTy = ED->getIntegerType();
 
   return (isPromotableIntegerTypeForABI(RetTy) ? ABIArgInfo::getExtend(RetTy)
                                                : ABIArgInfo::getDirect());
@@ -182,8 +181,8 @@ ABIArgInfo NVPTXABIInfo::classifyReturnType(QualType RetTy) const {
 
 ABIArgInfo NVPTXABIInfo::classifyArgumentType(QualType Ty) const {
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   // Return aggregates type as indirect by value
   if (isAggregateTypeForABI(Ty)) {
diff --git a/clang/lib/CodeGen/Targets/PPC.cpp b/clang/lib/CodeGen/Targets/PPC.cpp
index 38e76399299ec..380e8c06c46f0 100644
--- a/clang/lib/CodeGen/Targets/PPC.cpp
+++ b/clang/lib/CodeGen/Targets/PPC.cpp
@@ -153,8 +153,8 @@ class AIXTargetCodeGenInfo : public TargetCodeGenInfo {
 // extended to 32/64 bits.
 bool AIXABIInfo::isPromotableTypeForABI(QualType Ty) const {
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   // Promotable integer types are required to be promoted by the ABI.
   if (getContext().isPromotableIntegerType(Ty))
@@ -294,10 +294,7 @@ void AIXTargetCodeGenInfo::setTargetAttributes(
     ASTContext &Context = D->getASTContext();
     unsigned Alignment = Context.toBits(Context.getDeclAlign(D)) / 8;
     const auto *Ty = VarD->getType().getTypePtr();
-    const RecordDecl *RDecl =
-        Ty->isRecordType()
-            ? Ty->getAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf()
-            : nullptr;
+    const RecordDecl *RDecl = Ty->getAsRecordDecl();
 
     bool EmitDiagnostic = UserSpecifiedTOC && GV->hasExternalLinkage();
     auto reportUnsupportedWarning = [&](bool ShouldEmitWarning, StringRef Msg) {
@@ -708,8 +705,8 @@ class PPC64TargetCodeGenInfo : public TargetCodeGenInfo {
 bool
 PPC64_SVR4_ABIInfo::isPromotableTypeForABI(QualType Ty) const {
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   // Promotable integer types are required to be promoted by the ABI.
   if (isPromotableIntegerTypeForABI(Ty))
diff --git a/clang/lib/CodeGen/Targets/RISCV.cpp b/clang/lib/CodeGen/Targets/RISCV.cpp
index cbb0255fdb03a..adcf04c38a0f8 100644
--- a/clang/lib/CodeGen/Targets/RISCV.cpp
+++ b/clang/lib/CodeGen/Targets/RISCV.cpp
@@ -233,7 +233,7 @@ bool RISCVABIInfo::detectFPCCEligibleStructHelper(QualType Ty, CharUnits CurOff,
     QualType EltTy = ATy->getElementType();
     // Non-zero-length arrays of empty records make the struct ineligible for
     // the FP calling convention in C++.
-    if (const auto *RTy = EltTy->getAs<RecordType>()) {
+    if (const auto *RTy = EltTy->getAsCanonical<RecordType>()) {
       if (ArraySize != 0 && isa<CXXRecordDecl>(RTy->getOriginalDecl()) &&
           isEmptyRecord(getContext(), EltTy, true, true))
         return false;
@@ -249,7 +249,7 @@ bool RISCVABIInfo::detectFPCCEligibleStructHelper(QualType Ty, CharUnits CurOff,
     return true;
   }
 
-  if (const auto *RTy = Ty->getAs<RecordType>()) {
+  if (const auto *RTy = Ty->getAsCanonical<RecordType>()) {
     // Structures with either a non-trivial destructor or a non-trivial
     // copy constructor are not eligible for the FP calling convention.
     if (getRecordArgABI(Ty, CGT.getCXXABI()))
@@ -264,10 +264,7 @@ bool RISCVABIInfo::detectFPCCEligibleStructHelper(QualType Ty, CharUnits CurOff,
     // If this is a C++ record, check the bases first.
     if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
       for (const CXXBaseSpecifier &B : CXXRD->bases()) {
-        const auto *BDecl =
-            cast<CXXRecordDecl>(
-                B.getType()->castAs<RecordType>()->getOriginalDecl())
-                ->getDefinitionOrSelf();
+        const auto *BDecl = B.getType()->castAsCXXRecordDecl();
         CharUnits BaseOff = Layout.getBaseClassOffset(BDecl);
         bool Ret = detectFPCCEligibleStructHelper(B.getType(), CurOff + BaseOff,
                                                   Field1Ty, Field1Off, Field2Ty,
@@ -680,8 +677,8 @@ ABIArgInfo RISCVABIInfo::classifyArgumentType(QualType Ty, bool IsFixed,
 
   if (!isAggregateTypeForABI(Ty) && !Ty->isVectorType()) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     // All integral types are promoted to XLen width
     if (Size < XLen && Ty->isIntegralOrEnumerationType()) {
diff --git a/clang/lib/CodeGen/Targets/SPIR.cpp b/clang/lib/CodeGen/Targets/SPIR.cpp
index 237aea755fa29..ababb25736203 100644
--- a/clang/lib/CodeGen/Targets/SPIR.cpp
+++ b/clang/lib/CodeGen/Targets/SPIR.cpp
@@ -118,11 +118,9 @@ ABIArgInfo SPIRVABIInfo::classifyReturnType(QualType RetTy) const {
   if (!isAggregateTypeForABI(RetTy) || getRecordArgABI(RetTy, getCXXABI()))
     return DefaultABIInfo::classifyReturnType(RetTy);
 
-  if (const RecordType *RT = RetTy->getAs<RecordType>()) {
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
-    if (RD->hasFlexibleArrayMember())
-      return DefaultABIInfo::classifyReturnType(RetTy);
-  }
+  if (const auto *RD = RetTy->getAsRecordDecl();
+      RD && RD->hasFlexibleArrayMember())
+    return DefaultABIInfo::classifyReturnType(RetTy);
 
   // TODO: The AMDGPU ABI is non-trivial to represent in SPIR-V; in order to
   // avoid encoding various architecture specific bits here we return everything
@@ -186,11 +184,9 @@ ABIArgInfo SPIRVABIInfo::classifyArgumentType(QualType Ty) const {
     return getNaturalAlignIndirect(Ty, getDataLayout().getAllocaAddrSpace(),
                                    RAA == CGCXXABI::RAA_DirectInMemory);
 
-  if (const RecordType *RT = Ty->getAs<RecordType>()) {
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
-    if (RD->hasFlexibleArrayMember())
-      return DefaultABIInfo::classifyArgumentType(Ty);
-  }
+  if (const auto *RD = Ty->getAsRecordDecl();
+      RD && RD->hasFlexibleArrayMember())
+    return DefaultABIInfo::classifyArgumentType(Ty);
 
   return ABIArgInfo::getDirect(CGT.ConvertType(Ty), 0u, nullptr, false);
 }
@@ -431,8 +427,7 @@ static llvm::Type *getInlineSpirvType(CodeGenModule &CGM,
     }
     case SpirvOperandKind::TypeId: {
       QualType TypeOperand = Operand.getResultType();
-      if (auto *RT = TypeOperand->getAs<RecordType>()) {
-        auto *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
+      if (const auto *RD = TypeOperand->getAsRecordDecl()) {
         assert(RD->isCompleteDefinition() &&
                "Type completion should have been required in Sema");
 
diff --git a/clang/lib/CodeGen/Targets/Sparc.cpp b/clang/lib/CodeGen/Targets/Sparc.cpp
index 13a42f66f5843..5f3c15d106eb6 100644
--- a/clang/lib/CodeGen/Targets/Sparc.cpp
+++ b/clang/lib/CodeGen/Targets/Sparc.cpp
@@ -237,8 +237,8 @@ SparcV9ABIInfo::classifyType(QualType Ty, unsigned SizeLimit) const {
         /*ByVal=*/false);
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   // Integer types smaller than a register are extended.
   if (Size < 64 && Ty->isIntegerType())
diff --git a/clang/lib/CodeGen/Targets/SystemZ.cpp b/clang/lib/CodeGen/Targets/SystemZ.cpp
index 38cc4d39126be..9b6b72b1dcc05 100644
--- a/clang/lib/CodeGen/Targets/SystemZ.cpp
+++ b/clang/lib/CodeGen/Targets/SystemZ.cpp
@@ -145,8 +145,8 @@ class SystemZTargetCodeGenInfo : public TargetCodeGenInfo {
 
 bool SystemZABIInfo::isPromotableIntegerTypeForABI(QualType Ty) const {
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   // Promotable integer types are required to be promoted by the ABI.
   if (ABIInfo::isPromotableIntegerTypeForABI(Ty))
@@ -208,10 +208,8 @@ llvm::Type *SystemZABIInfo::getFPArgumentType(QualType Ty,
 }
 
 QualType SystemZABIInfo::GetSingleElementType(QualType Ty) const {
-  const RecordType *RT = Ty->getAs<RecordType>();
-
-  if (RT && RT->isStructureOrClassType()) {
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *RD = Ty->getAsRecordDecl();
+  if (RD && RD->isStructureOrClass()) {
     QualType Found;
 
     // If this is a C++ record, check the bases first.
@@ -452,10 +450,9 @@ ABIArgInfo SystemZABIInfo::classifyArgumentType(QualType Ty) const {
                                    /*ByVal=*/false);
 
   // Handle small structures.
-  if (const RecordType *RT = Ty->getAs<RecordType>()) {
+  if (const auto *RD = Ty->getAsRecordDecl()) {
     // Structures with flexible arrays have variable length, so really
     // fail the size test above.
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
     if (RD->hasFlexibleArrayMember())
       return getNaturalAlignIndirect(Ty, getDataLayout().getAllocaAddrSpace(),
                                      /*ByVal=*/false);
@@ -525,8 +522,7 @@ bool SystemZTargetCodeGenInfo::isVectorTypeBased(const Type *Ty,
   if (Ty->isVectorType() && Ctx.getTypeSize(Ty) / 8 >= 16)
       return true;
 
-  if (const auto *RecordTy = Ty->getAs<RecordType>()) {
-    const RecordDecl *RD = RecordTy->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *RD = Ty->getAsRecordDecl()) {
     if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD))
       if (CXXRD->hasDefinition())
         for (const auto &I : CXXRD->bases())
diff --git a/clang/lib/CodeGen/Targets/WebAssembly.cpp b/clang/lib/CodeGen/Targets/WebAssembly.cpp
index ac8dcd2a0540a..ebe996a4edd8d 100644
--- a/clang/lib/CodeGen/Targets/WebAssembly.cpp
+++ b/clang/lib/CodeGen/Targets/WebAssembly.cpp
@@ -115,11 +115,9 @@ ABIArgInfo WebAssemblyABIInfo::classifyArgumentType(QualType Ty) const {
       return ABIArgInfo::getDirect(CGT.ConvertType(QualType(SeltTy, 0)));
     // For the experimental multivalue ABI, fully expand all other aggregates
     if (Kind == WebAssemblyABIKind::ExperimentalMV) {
-      const RecordType *RT = Ty->getAs<RecordType>();
-      assert(RT);
+      const auto *RD = Ty->castAsRecordDecl();
       bool HasBitField = false;
-      for (auto *Field :
-           RT->getOriginalDecl()->getDefinitionOrSelf()->fields()) {
+      for (auto *Field : RD->fields()) {
         if (Field->isBitField()) {
           HasBitField = true;
           break;
diff --git a/clang/lib/CodeGen/Targets/X86.cpp b/clang/lib/CodeGen/Targets/X86.cpp
index 61b833a9c1b43..71db63bb89645 100644
--- a/clang/lib/CodeGen/Targets/X86.cpp
+++ b/clang/lib/CodeGen/Targets/X86.cpp
@@ -352,15 +352,15 @@ bool X86_32ABIInfo::shouldReturnTypeInRegister(QualType Ty,
     return shouldReturnTypeInRegister(AT->getElementType(), Context);
 
   // Otherwise, it must be a record type.
-  const RecordType *RT = Ty->getAs<RecordType>();
-  if (!RT) return false;
+  const auto *RD = Ty->getAsRecordDecl();
+  if (!RD)
+    return false;
 
   // FIXME: Traverse bases here too.
 
   // Structure types are passed in register if all fields would be
   // passed in a register.
-  for (const auto *FD :
-       RT->getOriginalDecl()->getDefinitionOrSelf()->fields()) {
+  for (const auto *FD : RD->fields()) {
     // Empty fields are ignored.
     if (isEmptyField(Context, FD, true))
       continue;
@@ -427,11 +427,10 @@ static bool addBaseAndFieldSizes(ASTContext &Context, const CXXRecordDecl *RD,
 /// optimizations.
 bool X86_32ABIInfo::canExpandIndirectArgument(QualType Ty) const {
   // We can only expand structure types.
-  const RecordType *RT = Ty->getAs<RecordType>();
-  if (!RT)
+  const RecordDecl *RD = Ty->getAsRecordDecl();
+  if (!RD)
     return false;
   uint64_t Size = 0;
-  const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
   if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
     if (!IsWin32StructABI) {
       // On non-Windows, we have to conservatively match our old bitcode
@@ -508,13 +507,10 @@ ABIArgInfo X86_32ABIInfo::classifyReturnType(QualType RetTy,
   }
 
   if (isAggregateTypeForABI(RetTy)) {
-    if (const RecordType *RT = RetTy->getAs<RecordType>()) {
+    if (const auto *RD = RetTy->getAsRecordDecl();
+        RD && RD->hasFlexibleArrayMember())
       // Structures with flexible arrays are always indirect.
-      if (RT->getOriginalDecl()
-              ->getDefinitionOrSelf()
-              ->hasFlexibleArrayMember())
-        return getIndirectReturnResult(RetTy, State);
-    }
+      return getIndirectReturnResult(RetTy, State);
 
     // If specified, structs and unions are always indirect.
     if (!IsRetSmallStructInRegABI && !RetTy->isAnyComplexType())
@@ -556,8 +552,8 @@ ABIArgInfo X86_32ABIInfo::classifyReturnType(QualType RetTy,
   }
 
   // Treat an enum type as its underlying type.
-  if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-    RetTy = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = RetTy->getAsEnumDecl())
+    RetTy = ED->getIntegerType();
 
   if (const auto *EIT = RetTy->getAs<BitIntType>())
     if (EIT->getNumBits() > 64)
@@ -756,7 +752,7 @@ ABIArgInfo X86_32ABIInfo::classifyArgumentType(QualType Ty, CCState &State,
   TypeInfo TI = getContext().getTypeInfo(Ty);
 
   // Check with the C++ ABI first.
-  const RecordType *RT = Ty->getAs<RecordType>();
+  const RecordType *RT = Ty->getAsCanonical<RecordType>();
   if (RT) {
     CGCXXABI::RecordArgABI RAA = getRecordArgABI(RT, getCXXABI());
     if (RAA == CGCXXABI::RAA_Indirect) {
@@ -885,9 +881,8 @@ ABIArgInfo X86_32ABIInfo::classifyArgumentType(QualType Ty, CCState &State,
     return ABIArgInfo::getDirect();
   }
 
-
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   bool InReg = shouldPrimitiveUseInReg(Ty, State);
 
@@ -1850,10 +1845,9 @@ void X86_64ABIInfo::classify(QualType Ty, uint64_t OffsetBase, Class &Lo,
     return;
   }
 
-  if (const EnumType *ET = Ty->getAs<EnumType>()) {
+  if (const auto *ED = Ty->getAsEnumDecl()) {
     // Classify the underlying integer type.
-    classify(ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType(),
-             OffsetBase, Lo, Hi, isNamedArg);
+    classify(ED->getIntegerType(), OffsetBase, Lo, Hi, isNamedArg);
     return;
   }
 
@@ -2045,7 +2039,7 @@ void X86_64ABIInfo::classify(QualType Ty, uint64_t OffsetBase, Class &Lo,
     return;
   }
 
-  if (const RecordType *RT = Ty->getAs<RecordType>()) {
+  if (const RecordType *RT = Ty->getAsCanonical<RecordType>()) {
     uint64_t Size = getContext().getTypeSize(Ty);
 
     // AMD64-ABI 3.2.3p2: Rule 1. If the size of an object is larger
@@ -2075,11 +2069,7 @@ void X86_64ABIInfo::classify(QualType Ty, uint64_t OffsetBase, Class &Lo,
       for (const auto &I : CXXRD->bases()) {
         assert(!I.isVirtual() && !I.getType()->isDependentType() &&
                "Unexpected base class!");
-        const auto *Base =
-            cast<CXXRecordDecl>(
-                I.getType()->castAs<RecordType>()->getOriginalDecl())
-                ->getDefinitionOrSelf();
-
+        const auto *Base = I.getType()->castAsCXXRecordDecl();
         // Classify this field.
         //
         // AMD64-ABI 3.2.3p2: Rule 3. If the size of the aggregate exceeds a
@@ -2191,8 +2181,8 @@ ABIArgInfo X86_64ABIInfo::getIndirectReturnResult(QualType Ty) const {
   // place naturally.
   if (!isAggregateTypeForABI(Ty)) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     if (Ty->isBitIntType())
       return getNaturalAlignIndirect(Ty, getDataLayout().getAllocaAddrSpace());
@@ -2233,8 +2223,8 @@ ABIArgInfo X86_64ABIInfo::getIndirectResult(QualType Ty,
   if (!isAggregateTypeForABI(Ty) && !IsIllegalVectorType(Ty) &&
       !Ty->isBitIntType()) {
     // Treat an enum type as its underlying type.
-    if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-      Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = Ty->getAsEnumDecl())
+      Ty = ED->getIntegerType();
 
     return (isPromotableIntegerTypeForABI(Ty) ? ABIArgInfo::getExtend(Ty)
                                               : ABIArgInfo::getDirect());
@@ -2354,8 +2344,7 @@ static bool BitsContainNoUserData(QualType Ty, unsigned StartBit,
     return true;
   }
 
-  if (const RecordType *RT = Ty->getAs<RecordType>()) {
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *RD = Ty->getAsRecordDecl()) {
     const ASTRecordLayout &Layout = Context.getASTRecordLayout(RD);
 
     // If this is a C++ record, check the bases first.
@@ -2363,10 +2352,7 @@ static bool BitsContainNoUserData(QualType Ty, unsigned StartBit,
       for (const auto &I : CXXRD->bases()) {
         assert(!I.isVirtual() && !I.getType()->isDependentType() &&
                "Unexpected base class!");
-        const auto *Base =
-            cast<CXXRecordDecl>(
-                I.getType()->castAs<RecordType>()->getOriginalDecl())
-                ->getDefinitionOrSelf();
+        const auto *Base = I.getType()->castAsCXXRecordDecl();
 
         // If the base is after the span we care about, ignore it.
         unsigned BaseOffset = Context.toBits(Layout.getBaseClassOffset(Base));
@@ -2646,9 +2632,8 @@ ABIArgInfo X86_64ABIInfo::classifyReturnType(QualType RetTy) const {
     // so that the parameter gets the right LLVM IR attributes.
     if (Hi == NoClass && isa<llvm::IntegerType>(ResType)) {
       // Treat an enum type as its underlying type.
-      if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
-        RetTy =
-            EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+      if (const auto *ED = RetTy->getAsEnumDecl())
+        RetTy = ED->getIntegerType();
 
       if (RetTy->isIntegralOrEnumerationType() &&
           isPromotableIntegerTypeForABI(RetTy))
@@ -2797,8 +2782,8 @@ X86_64ABIInfo::classifyArgumentType(QualType Ty, unsigned freeIntRegs,
     // so that the parameter gets the right LLVM IR attributes.
     if (Hi == NoClass && isa<llvm::IntegerType>(ResType)) {
       // Treat an enum type as its underlying type.
-      if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-        Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+      if (const auto *ED = Ty->getAsEnumDecl())
+        Ty = ED->getIntegerType();
 
       if (Ty->isIntegralOrEnumerationType() &&
           isPromotableIntegerTypeForABI(Ty))
@@ -3324,14 +3309,14 @@ ABIArgInfo WinX86_64ABIInfo::classify(QualType Ty, unsigned &FreeSSERegs,
   if (Ty->isVoidType())
     return ABIArgInfo::getIgnore();
 
-  if (const EnumType *EnumTy = Ty->getAs<EnumType>())
-    Ty = EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = Ty->getAsEnumDecl())
+    Ty = ED->getIntegerType();
 
   TypeInfo Info = getContext().getTypeInfo(Ty);
   uint64_t Width = Info.Width;
   CharUnits Align = getContext().toCharUnitsFromBits(Info.Align);
 
-  const RecordType *RT = Ty->getAs<RecordType>();
+  const RecordType *RT = Ty->getAsCanonical<RecordType>();
   if (RT) {
     if (!IsReturnType) {
       if (CGCXXABI::RecordArgABI RAA = getRecordArgABI(RT, getCXXABI()))
diff --git a/clang/lib/CodeGen/Targets/XCore.cpp b/clang/lib/CodeGen/Targets/XCore.cpp
index d65af5f82b7f8..8199423b8eb7c 100644
--- a/clang/lib/CodeGen/Targets/XCore.cpp
+++ b/clang/lib/CodeGen/Targets/XCore.cpp
@@ -615,7 +615,7 @@ static bool appendType(SmallStringEnc &Enc, QualType QType,
   if (const PointerType *PT = QT->getAs<PointerType>())
     return appendPointerType(Enc, PT, CGM, TSC);
 
-  if (const EnumType *ET = QT->getAs<EnumType>())
+  if (const EnumType *ET = QT->getAsCanonical<EnumType>())
     return appendEnumType(Enc, ET, TSC, QT.getBaseTypeIdentifier());
 
   if (const RecordType *RT = QT->getAsStructureType())
diff --git a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
index 627a1d6fb3dd5..40f8348241ecc 100644
--- a/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
+++ b/clang/lib/Edit/RewriteObjCFoundationAPI.cpp
@@ -1011,7 +1011,7 @@ static bool rewriteToNumericBoxedExpression(const ObjCMessageExpr *Msg,
       if ((MK == NSAPI::NSNumberWithInteger ||
            MK == NSAPI::NSNumberWithUnsignedInteger) &&
           !isTruncated) {
-        if (OrigTy->getAs<EnumType>() || isEnumConstant(OrigArg))
+        if (OrigTy->isEnumeralType() || isEnumConstant(OrigArg))
           break;
         if ((MK==NSAPI::NSNumberWithInteger) == OrigTy->isSignedIntegerType() &&
             OrigTySize >= Ctx.getTypeSize(Ctx.IntTy))
diff --git a/clang/lib/Frontend/Rewrite/RewriteModernObjC.cpp b/clang/lib/Frontend/Rewrite/RewriteModernObjC.cpp
index 2b55820b38804..42f2d6591b213 100644
--- a/clang/lib/Frontend/Rewrite/RewriteModernObjC.cpp
+++ b/clang/lib/Frontend/Rewrite/RewriteModernObjC.cpp
@@ -852,7 +852,7 @@ RewriteModernObjC::getIvarAccessString(ObjCIvarDecl *D) {
     IvarT = GetGroupRecordTypeForObjCIvarBitfield(D);
 
   if (!IvarT->getAs<TypedefType>() && IvarT->isRecordType()) {
-    RecordDecl *RD = IvarT->castAs<RecordType>()->getOriginalDecl();
+    RecordDecl *RD = IvarT->castAsCanonical<RecordType>()->getOriginalDecl();
     RD = RD->getDefinition();
     if (RD && !RD->getDeclName().getAsIdentifierInfo()) {
       // decltype(((Foo_IMPL*)0)->bar) *
@@ -3638,8 +3638,7 @@ bool RewriteModernObjC::RewriteObjCFieldDeclType(QualType &Type,
     return RewriteObjCFieldDeclType(ElemTy, Result);
   }
   else if (Type->isRecordType()) {
-    RecordDecl *RD =
-        Type->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    auto *RD = Type->castAsRecordDecl();
     if (RD->isCompleteDefinition()) {
       if (RD->isStruct())
         Result += "\n\tstruct ";
@@ -3660,28 +3659,26 @@ bool RewriteModernObjC::RewriteObjCFieldDeclType(QualType &Type,
       Result += "\t} ";
       return true;
     }
-  }
-  else if (Type->isEnumeralType()) {
-    EnumDecl *ED =
-        Type->castAs<EnumType>()->getOriginalDecl()->getDefinitionOrSelf();
-    if (ED->isCompleteDefinition()) {
-      Result += "\n\tenum ";
-      Result += ED->getName();
-      if (GlobalDefinedTags.count(ED)) {
-        // Enum is globall defined, use it.
-        Result += " ";
-        return true;
-      }
-
-      Result += " {\n";
-      for (const auto *EC : ED->enumerators()) {
-        Result += "\t"; Result += EC->getName(); Result += " = ";
-        Result += toString(EC->getInitVal(), 10);
-        Result += ",\n";
-      }
-      Result += "\t} ";
+  } else if (auto *ED = Type->getAsEnumDecl();
+             ED && ED->isCompleteDefinition()) {
+    Result += "\n\tenum ";
+    Result += ED->getName();
+    if (GlobalDefinedTags.count(ED)) {
+      // Enum is globally defined, use it.
+      Result += " ";
       return true;
     }
+
+    Result += " {\n";
+    for (const auto *EC : ED->enumerators()) {
+      Result += "\t";
+      Result += EC->getName();
+      Result += " = ";
+      Result += toString(EC->getInitVal(), 10);
+      Result += ",\n";
+    }
+    Result += "\t} ";
+    return true;
   }
 
   Result += "\t";
@@ -3733,15 +3730,7 @@ void RewriteModernObjC::RewriteLocallyDefinedNamedAggregates(FieldDecl *fieldDec
 
   auto *IDecl = dyn_cast<ObjCContainerDecl>(fieldDecl->getDeclContext());
 
-  TagDecl *TD = nullptr;
-  if (Type->isRecordType()) {
-    TD = Type->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
-  }
-  else if (Type->isEnumeralType()) {
-    TD = Type->castAs<EnumType>()->getOriginalDecl()->getDefinitionOrSelf();
-  }
-
-  if (TD) {
+  if (auto *TD = Type->getAsTagDecl()) {
     if (GlobalDefinedTags.count(TD))
       return;
 
@@ -5723,10 +5712,7 @@ void RewriteModernObjC::HandleDeclInMainFile(Decl *D) {
           }
         }
       } else if (VD->getType()->isRecordType()) {
-        RecordDecl *RD = VD->getType()
-                             ->castAs<RecordType>()
-                             ->getOriginalDecl()
-                             ->getDefinitionOrSelf();
+        auto *RD = VD->getType()->castAsRecordDecl();
         if (RD->isCompleteDefinition())
           RewriteRecordBody(RD);
       }
@@ -7467,7 +7453,8 @@ Stmt *RewriteModernObjC::RewriteObjCIvarRefExpr(ObjCIvarRefExpr *IV) {
         IvarT = GetGroupRecordTypeForObjCIvarBitfield(D);
 
       if (!IvarT->getAs<TypedefType>() && IvarT->isRecordType()) {
-        RecordDecl *RD = IvarT->castAs<RecordType>()->getOriginalDecl();
+        RecordDecl *RD =
+            IvarT->castAsCanonical<RecordType>()->getOriginalDecl();
         RD = RD->getDefinition();
         if (RD && !RD->getDeclName().getAsIdentifierInfo()) {
           // decltype(((Foo_IMPL*)0)->bar) *
diff --git a/clang/lib/Frontend/Rewrite/RewriteObjC.cpp b/clang/lib/Frontend/Rewrite/RewriteObjC.cpp
index 0242fc1f6e827..b9c025de87739 100644
--- a/clang/lib/Frontend/Rewrite/RewriteObjC.cpp
+++ b/clang/lib/Frontend/Rewrite/RewriteObjC.cpp
@@ -4835,10 +4835,7 @@ void RewriteObjC::HandleDeclInMainFile(Decl *D) {
           }
         }
       } else if (VD->getType()->isRecordType()) {
-        RecordDecl *RD = VD->getType()
-                             ->castAs<RecordType>()
-                             ->getOriginalDecl()
-                             ->getDefinitionOrSelf();
+        auto *RD = VD->getType()->castAsRecordDecl();
         if (RD->isCompleteDefinition())
           RewriteRecordBody(RD);
       }
diff --git a/clang/lib/Index/IndexTypeSourceInfo.cpp b/clang/lib/Index/IndexTypeSourceInfo.cpp
index 9a699c3c896de..74c6c116b274e 100644
--- a/clang/lib/Index/IndexTypeSourceInfo.cpp
+++ b/clang/lib/Index/IndexTypeSourceInfo.cpp
@@ -61,7 +61,7 @@ class TypeIndexer : public RecursiveASTVisitor<TypeIndexer> {
     SourceLocation Loc = TL.getNameLoc();
     TypedefNameDecl *ND = TL.getDecl();
     if (ND->isTransparentTag()) {
-      TagDecl *Underlying = ND->getUnderlyingType()->getAsTagDecl();
+      auto *Underlying = ND->getUnderlyingType()->castAsTagDecl();
       return IndexCtx.handleReference(Underlying, Loc, Parent,
                                       ParentDC, SymbolRoleSet(), Relations);
     }
diff --git a/clang/lib/Index/USRGeneration.cpp b/clang/lib/Index/USRGeneration.cpp
index 6bdddcc6967c9..c78d66f9502dd 100644
--- a/clang/lib/Index/USRGeneration.cpp
+++ b/clang/lib/Index/USRGeneration.cpp
@@ -931,7 +931,8 @@ void USRGenerator::VisitType(QualType T) {
         VisitObjCProtocolDecl(Prot);
       return;
     }
-    if (const TemplateTypeParmType *TTP = T->getAs<TemplateTypeParmType>()) {
+    if (const TemplateTypeParmType *TTP =
+            T->getAsCanonical<TemplateTypeParmType>()) {
       Out << 't' << TTP->getDepth() << '.' << TTP->getIndex();
       return;
     }
diff --git a/clang/lib/Interpreter/InterpreterValuePrinter.cpp b/clang/lib/Interpreter/InterpreterValuePrinter.cpp
index 4f5f43e1b3ade..fbbed6c095c5b 100644
--- a/clang/lib/Interpreter/InterpreterValuePrinter.cpp
+++ b/clang/lib/Interpreter/InterpreterValuePrinter.cpp
@@ -101,15 +101,11 @@ static std::string EnumToString(const Value &V) {
   llvm::raw_string_ostream SS(Str);
   ASTContext &Ctx = const_cast<ASTContext &>(V.getASTContext());
 
-  QualType DesugaredTy = V.getType().getDesugaredType(Ctx);
-  const EnumType *EnumTy = DesugaredTy.getNonReferenceType()->getAs<EnumType>();
-  assert(EnumTy && "Fail to cast to enum type");
-
-  EnumDecl *ED = EnumTy->getOriginalDecl()->getDefinitionOrSelf();
   uint64_t Data = V.convertTo<uint64_t>();
   bool IsFirst = true;
-  llvm::APSInt AP = Ctx.MakeIntValue(Data, DesugaredTy);
+  llvm::APSInt AP = Ctx.MakeIntValue(Data, V.getType());
 
+  auto *ED = V.getType()->castAsEnumDecl();
   for (auto I = ED->enumerator_begin(), E = ED->enumerator_end(); I != E; ++I) {
     if (I->getInitVal() == AP) {
       if (!IsFirst)
@@ -665,8 +661,8 @@ __clang_Interpreter_SetValueNoAlloc(void *This, void *OutVal, void *OpaqueType,
   if (VRef.getKind() == Value::K_PtrOrObj) {
     VRef.setPtr(va_arg(args, void *));
   } else {
-    if (const auto *ET = QT->getAs<EnumType>())
-      QT = ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = QT->getAsEnumDecl())
+      QT = ED->getIntegerType();
     switch (QT->castAs<BuiltinType>()->getKind()) {
     default:
       llvm_unreachable("unknown type kind!");
diff --git a/clang/lib/Interpreter/Value.cpp b/clang/lib/Interpreter/Value.cpp
index 84ba508e9cbc8..d4c9d51ffcb61 100644
--- a/clang/lib/Interpreter/Value.cpp
+++ b/clang/lib/Interpreter/Value.cpp
@@ -101,8 +101,8 @@ static Value::Kind ConvertQualTypeToKind(const ASTContext &Ctx, QualType QT) {
   if (Ctx.hasSameType(QT, Ctx.VoidTy))
     return Value::K_Void;
 
-  if (const auto *ET = QT->getAs<EnumType>())
-    QT = ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = QT->getAsEnumDecl())
+    QT = ED->getIntegerType();
 
   const auto *BT = QT->getAs<BuiltinType>();
   if (!BT || BT->isNullPtrType())
@@ -147,15 +147,12 @@ Value::Value(const Interpreter *In, void *Ty) : Interp(In), OpaqueType(Ty) {
         } while (ArrTy);
         ElementsSize = static_cast<size_t>(ArrSize.getZExtValue());
       }
-      if (const auto *RT = DtorTy->getAs<RecordType>()) {
-        if (CXXRecordDecl *CXXRD =
-                llvm::dyn_cast<CXXRecordDecl>(RT->getOriginalDecl())) {
-          if (llvm::Expected<llvm::orc::ExecutorAddr> Addr =
-                  Interp.CompileDtorCall(CXXRD->getDefinitionOrSelf()))
-            DtorF = reinterpret_cast<void *>(Addr->getValue());
-          else
-            llvm::logAllUnhandledErrors(Addr.takeError(), llvm::errs());
-        }
+      if (auto *CXXRD = DtorTy->getAsCXXRecordDecl()) {
+        if (llvm::Expected<llvm::orc::ExecutorAddr> Addr =
+                Interp.CompileDtorCall(CXXRD))
+          DtorF = reinterpret_cast<void *>(Addr->getValue());
+        else
+          llvm::logAllUnhandledErrors(Addr.takeError(), llvm::errs());
       }
 
       size_t AllocSize =
diff --git a/clang/lib/Sema/Sema.cpp b/clang/lib/Sema/Sema.cpp
index b870c5af1e588..39fa25f66f3b7 100644
--- a/clang/lib/Sema/Sema.cpp
+++ b/clang/lib/Sema/Sema.cpp
@@ -1882,23 +1882,21 @@ class DeferredDiagnosticsEmitter
     // Visit the dtors of all members
     for (const FieldDecl *FD : RD->fields()) {
       QualType FT = FD->getType();
-      if (const auto *RT = FT->getAs<RecordType>())
-        if (const auto *ClassDecl =
-                dyn_cast<CXXRecordDecl>(RT->getOriginalDecl()))
-          if (const auto *Def = ClassDecl->getDefinition())
-            if (CXXDestructorDecl *MemberDtor = Def->getDestructor())
-              asImpl().visitUsedDecl(MemberDtor->getLocation(), MemberDtor);
+      if (const auto *ClassDecl = FT->getAsCXXRecordDecl();
+          ClassDecl &&
+          (ClassDecl->isBeingDefined() || ClassDecl->isCompleteDefinition()))
+        if (CXXDestructorDecl *MemberDtor = ClassDecl->getDestructor())
+          asImpl().visitUsedDecl(MemberDtor->getLocation(), MemberDtor);
     }
 
     // Also visit base class dtors
     for (const auto &Base : RD->bases()) {
       QualType BaseType = Base.getType();
-      if (const auto *RT = BaseType->getAs<RecordType>())
-        if (const auto *BaseDecl =
-                dyn_cast<CXXRecordDecl>(RT->getOriginalDecl()))
-          if (const auto *Def = BaseDecl->getDefinition())
-            if (CXXDestructorDecl *BaseDtor = Def->getDestructor())
-              asImpl().visitUsedDecl(BaseDtor->getLocation(), BaseDtor);
+      if (const auto *BaseDecl = BaseType->getAsCXXRecordDecl();
+          BaseDecl &&
+          (BaseDecl->isBeingDefined() || BaseDecl->isCompleteDefinition()))
+        if (CXXDestructorDecl *BaseDtor = BaseDecl->getDestructor())
+          asImpl().visitUsedDecl(BaseDtor->getLocation(), BaseDtor);
     }
   }
 
@@ -1909,12 +1907,11 @@ class DeferredDiagnosticsEmitter
         if (VD->isThisDeclarationADefinition() &&
             VD->needsDestruction(S.Context)) {
           QualType VT = VD->getType();
-          if (const auto *RT = VT->getAs<RecordType>())
-            if (const auto *ClassDecl =
-                    dyn_cast<CXXRecordDecl>(RT->getOriginalDecl()))
-              if (const auto *Def = ClassDecl->getDefinition())
-                if (CXXDestructorDecl *Dtor = Def->getDestructor())
-                  asImpl().visitUsedDecl(Dtor->getLocation(), Dtor);
+          if (const auto *ClassDecl = VT->getAsCXXRecordDecl();
+              ClassDecl && (ClassDecl->isBeingDefined() ||
+                            ClassDecl->isCompleteDefinition()))
+            if (CXXDestructorDecl *Dtor = ClassDecl->getDestructor())
+              asImpl().visitUsedDecl(Dtor->getLocation(), Dtor);
         }
 
     Inherited::VisitDeclStmt(DS);
diff --git a/clang/lib/Sema/SemaAccess.cpp b/clang/lib/Sema/SemaAccess.cpp
index ba560d3c52340..17415b4185eff 100644
--- a/clang/lib/Sema/SemaAccess.cpp
+++ b/clang/lib/Sema/SemaAccess.cpp
@@ -439,10 +439,8 @@ static AccessResult MatchesFriend(Sema &S,
 static AccessResult MatchesFriend(Sema &S,
                                   const EffectiveContext &EC,
                                   CanQualType Friend) {
-  if (const RecordType *RT = Friend->getAs<RecordType>())
-    return MatchesFriend(
-        S, EC,
-        cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf());
+  if (const auto *RD = Friend->getAsCXXRecordDecl())
+    return MatchesFriend(S, EC, RD);
 
   // TODO: we can do better than this
   if (Friend->isDependentType())
@@ -1786,10 +1784,7 @@ Sema::AccessResult Sema::CheckMemberOperatorAccess(SourceLocation OpLoc,
   if (!getLangOpts().AccessControl || Found.getAccess() == AS_public)
     return AR_accessible;
 
-  const RecordType *RT = ObjectExpr->getType()->castAs<RecordType>();
-  CXXRecordDecl *NamingClass =
-      cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
-
+  auto *NamingClass = ObjectExpr->getType()->castAsCXXRecordDecl();
   AccessTarget Entity(Context, AccessTarget::Member, NamingClass, Found,
                       ObjectExpr->getType());
   Entity.setDiag(diag::err_access) << ObjectExpr->getSourceRange() << Range;
diff --git a/clang/lib/Sema/SemaBPF.cpp b/clang/lib/Sema/SemaBPF.cpp
index 6428435ed9d2a..be890ab7fa75f 100644
--- a/clang/lib/Sema/SemaBPF.cpp
+++ b/clang/lib/Sema/SemaBPF.cpp
@@ -56,14 +56,9 @@ static bool isValidPreserveTypeInfoArg(Expr *Arg) {
     return true;
 
   // Record type or Enum type.
-  const Type *Ty = ArgType->getUnqualifiedDesugaredType();
-  if (const auto *RT = Ty->getAs<RecordType>()) {
+  if (const auto *RT = ArgType->getAsCanonical<TagType>())
     if (!RT->getOriginalDecl()->getDeclName().isEmpty())
       return true;
-  } else if (const auto *ET = Ty->getAs<EnumType>()) {
-    if (!ET->getOriginalDecl()->getDeclName().isEmpty())
-      return true;
-  }
 
   return false;
 }
@@ -99,13 +94,12 @@ static bool isValidPreserveEnumValueArg(Expr *Arg) {
     return false;
 
   // The type must be EnumType.
-  const Type *Ty = ArgType->getUnqualifiedDesugaredType();
-  const auto *ET = Ty->getAs<EnumType>();
-  if (!ET)
+  const auto *ED = ArgType->getAsEnumDecl();
+  if (!ED)
     return false;
 
   // The enum value must be supported.
-  return llvm::is_contained(ET->getOriginalDecl()->enumerators(), Enumerator);
+  return llvm::is_contained(ED->enumerators(), Enumerator);
 }
 
 bool SemaBPF::CheckBPFBuiltinFunctionCall(unsigned BuiltinID,
diff --git a/clang/lib/Sema/SemaCUDA.cpp b/clang/lib/Sema/SemaCUDA.cpp
index fbf64d3d57050..2e3cbb336a0c8 100644
--- a/clang/lib/Sema/SemaCUDA.cpp
+++ b/clang/lib/Sema/SemaCUDA.cpp
@@ -421,13 +421,10 @@ bool SemaCUDA::inferTargetForImplicitSpecialMember(CXXRecordDecl *ClassDecl,
   }
 
   for (const auto *B : Bases) {
-    const RecordType *BaseType = B->getType()->getAs<RecordType>();
-    if (!BaseType) {
+    auto *BaseClassDecl = B->getType()->getAsCXXRecordDecl();
+    if (!BaseClassDecl)
       continue;
-    }
 
-    CXXRecordDecl *BaseClassDecl =
-        cast<CXXRecordDecl>(BaseType->getOriginalDecl())->getDefinitionOrSelf();
     Sema::SpecialMemberOverloadResult SMOR =
         SemaRef.LookupSpecialMember(BaseClassDecl, CSM,
                                     /* ConstArg */ ConstRHS,
@@ -466,15 +463,11 @@ bool SemaCUDA::inferTargetForImplicitSpecialMember(CXXRecordDecl *ClassDecl,
       continue;
     }
 
-    const RecordType *FieldType =
-        getASTContext().getBaseElementType(F->getType())->getAs<RecordType>();
-    if (!FieldType) {
+    auto *FieldRecDecl =
+        getASTContext().getBaseElementType(F->getType())->getAsCXXRecordDecl();
+    if (!FieldRecDecl)
       continue;
-    }
 
-    CXXRecordDecl *FieldRecDecl =
-        cast<CXXRecordDecl>(FieldType->getOriginalDecl())
-            ->getDefinitionOrSelf();
     Sema::SpecialMemberOverloadResult SMOR =
         SemaRef.LookupSpecialMember(FieldRecDecl, CSM,
                                     /* ConstArg */ ConstRHS && !F->isMutable(),
diff --git a/clang/lib/Sema/SemaCXXScopeSpec.cpp b/clang/lib/Sema/SemaCXXScopeSpec.cpp
index 45de8ff3ba264..ef14dde5203a6 100644
--- a/clang/lib/Sema/SemaCXXScopeSpec.cpp
+++ b/clang/lib/Sema/SemaCXXScopeSpec.cpp
@@ -133,11 +133,8 @@ DeclContext *Sema::computeDeclContext(const CXXScopeSpec &SS,
     return const_cast<NamespaceDecl *>(
         NNS.getAsNamespaceAndPrefix().Namespace->getNamespace());
 
-  case NestedNameSpecifier::Kind::Type: {
-    auto *TD = NNS.getAsType()->getAsTagDecl();
-    assert(TD && "Non-tag type in nested-name-specifier");
-    return TD;
-  }
+  case NestedNameSpecifier::Kind::Type:
+    return NNS.getAsType()->castAsTagDecl();
 
   case NestedNameSpecifier::Kind::Global:
     return Context.getTranslationUnitDecl();
diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp
index da43848a1a7d0..8ee3e0c7105b9 100644
--- a/clang/lib/Sema/SemaCast.cpp
+++ b/clang/lib/Sema/SemaCast.cpp
@@ -595,13 +595,11 @@ static void diagnoseBadCast(Sema &S, unsigned msg, CastType castType,
     DifferentPtrness--;
   }
   if (!DifferentPtrness) {
-    auto RecFrom = From->getAs<RecordType>();
-    auto RecTo = To->getAs<RecordType>();
-    if (RecFrom && RecTo) {
-      auto DeclFrom = RecFrom->getAsCXXRecordDecl();
+    if (auto *DeclFrom = From->getAsCXXRecordDecl(),
+        *DeclTo = To->getAsCXXRecordDecl();
+        DeclFrom && DeclTo) {
       if (!DeclFrom->isCompleteDefinition())
         S.Diag(DeclFrom->getLocation(), diag::note_type_incomplete) << DeclFrom;
-      auto DeclTo = RecTo->getAsCXXRecordDecl();
       if (!DeclTo->isCompleteDefinition())
         S.Diag(DeclTo->getLocation(), diag::note_type_incomplete) << DeclTo;
     }
@@ -865,7 +863,7 @@ void CastOperation::CheckDynamicCast() {
     return;
   }
 
-  const RecordType *DestRecord = DestPointee->getAs<RecordType>();
+  const auto *DestRecord = DestPointee->getAsCanonical<RecordType>();
   if (DestPointee->isVoidType()) {
     assert(DestPointer && "Reference to void is not possible");
   } else if (DestRecord) {
@@ -912,7 +910,7 @@ void CastOperation::CheckDynamicCast() {
     SrcPointee = SrcType;
   }
 
-  const RecordType *SrcRecord = SrcPointee->getAs<RecordType>();
+  const auto *SrcRecord = SrcPointee->getAsCanonical<RecordType>();
   if (SrcRecord) {
     if (Self.RequireCompleteType(OpRange.getBegin(), SrcPointee,
                                  diag::err_bad_cast_incomplete,
@@ -1454,7 +1452,7 @@ static TryCastResult TryStaticCast(Sema &Self, ExprResult &SrcExpr,
   // C++0x 5.2.9p9: A value of a scoped enumeration type can be explicitly
   // converted to an integral type. [...] A value of a scoped enumeration type
   // can also be explicitly converted to a floating-point type [...].
-  if (const EnumType *Enum = SrcType->getAs<EnumType>()) {
+  if (const EnumType *Enum = dyn_cast<EnumType>(SrcType)) {
     if (Enum->getOriginalDecl()->isScoped()) {
       if (DestType->isBooleanType()) {
         Kind = CK_IntegralToBoolean;
@@ -1486,8 +1484,7 @@ static TryCastResult TryStaticCast(Sema &Self, ExprResult &SrcExpr,
     if (SrcType->isIntegralOrEnumerationType()) {
       // [expr.static.cast]p10 If the enumeration type has a fixed underlying
       // type, the value is first converted to that type by integral conversion
-      const EnumType *Enum = DestType->castAs<EnumType>();
-      const EnumDecl *ED = Enum->getOriginalDecl()->getDefinitionOrSelf();
+      const auto *ED = DestType->castAsEnumDecl();
       Kind = ED->isFixed() && ED->getIntegerType()->isBooleanType()
                  ? CK_IntegralToBoolean
                  : CK_IntegralCast;
@@ -1581,11 +1578,11 @@ static TryCastResult TryStaticCast(Sema &Self, ExprResult &SrcExpr,
 
   // See if it looks like the user is trying to convert between
   // related record types, and select a better diagnostic if so.
-  if (auto SrcPointer = SrcType->getAs<PointerType>())
-    if (auto DestPointer = DestType->getAs<PointerType>())
-      if (SrcPointer->getPointeeType()->getAs<RecordType>() &&
-          DestPointer->getPointeeType()->getAs<RecordType>())
-       msg = diag::err_bad_cxx_cast_unrelated_class;
+  if (const auto *SrcPointer = SrcType->getAs<PointerType>())
+    if (const auto *DestPointer = DestType->getAs<PointerType>())
+      if (SrcPointer->getPointeeType()->isRecordType() &&
+          DestPointer->getPointeeType()->isRecordType())
+        msg = diag::err_bad_cxx_cast_unrelated_class;
 
   if (SrcType->isMatrixType() && DestType->isMatrixType()) {
     if (Self.CheckMatrixCast(OpRange, DestType, SrcType, Kind)) {
@@ -3097,7 +3094,8 @@ void CastOperation::CheckCStyleCast() {
 
   if (!DestType->isScalarType() && !DestType->isVectorType() &&
       !DestType->isMatrixType()) {
-    if (const RecordType *DestRecordTy = DestType->getAs<RecordType>()) {
+    if (const RecordType *DestRecordTy =
+            DestType->getAsCanonical<RecordType>()) {
       if (Self.Context.hasSameUnqualifiedType(DestType, SrcType)) {
         // GCC struct/union extension: allow cast to self.
         Self.Diag(OpRange.getBegin(), diag::ext_typecheck_cast_nonscalar)
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 6e777fb9aec8e..7c557f775d567 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -5278,13 +5278,10 @@ bool Sema::BuiltinVAStart(unsigned BuiltinID, CallExpr *TheCall) {
              // extra checking to see what their promotable type actually is.
              if (!Context.isPromotableIntegerType(Type))
                return false;
-             if (!Type->isEnumeralType())
+             const auto *ED = Type->getAsEnumDecl();
+             if (!ED)
                return true;
-             const EnumDecl *ED = Type->castAs<EnumType>()
-                                      ->getOriginalDecl()
-                                      ->getDefinitionOrSelf();
-             return !(ED &&
-                      Context.typesAreCompatible(ED->getPromotionType(), Type));
+             return !Context.typesAreCompatible(ED->getPromotionType(), Type);
            }()) {
     unsigned Reason = 0;
     if (Type->isReferenceType())  Reason = 1;
@@ -7891,16 +7888,10 @@ bool DecomposePrintfHandler::HandlePrintfSpecifier(
 template<typename MemberKind>
 static llvm::SmallPtrSet<MemberKind*, 1>
 CXXRecordMembersNamed(StringRef Name, Sema &S, QualType Ty) {
-  const RecordType *RT = Ty->getAs<RecordType>();
+  auto *RD = Ty->getAsCXXRecordDecl();
   llvm::SmallPtrSet<MemberKind*, 1> Results;
 
-  if (!RT)
-    return Results;
-  CXXRecordDecl *RD = dyn_cast<CXXRecordDecl>(RT->getOriginalDecl());
-  if (!RD)
-    return Results;
-  RD = RD->getDefinition();
-  if (!RD)
+  if (!RD || !(RD->isBeingDefined() || RD->isCompleteDefinition()))
     return Results;
 
   LookupResult R(S, &S.Context.Idents.get(Name), SourceLocation(),
@@ -8438,10 +8429,9 @@ CheckPrintfHandler::checkFormatExpr(const analyze_printf::PrintfSpecifier &FS,
   bool IsEnum = false;
   bool IsScopedEnum = false;
   QualType IntendedTy = ExprTy;
-  if (auto EnumTy = ExprTy->getAs<EnumType>()) {
-    IntendedTy =
-        EnumTy->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
-    if (EnumTy->isUnscopedEnumerationType()) {
+  if (const auto *ED = ExprTy->getAsEnumDecl()) {
+    IntendedTy = ED->getIntegerType();
+    if (!ED->isScoped()) {
       ExprTy = IntendedTy;
       // This controls whether we're talking about the underlying type or not,
       // which we only want to do when it's an unscoped enum.
@@ -9741,10 +9731,7 @@ struct SearchNonTrivialToInitializeField
     S.DiagRuntimeBehavior(SL, E, S.PDiag(diag::note_nontrivial_field) << 1);
   }
   void visitStruct(QualType FT, SourceLocation SL) {
-    for (const FieldDecl *FD : FT->castAs<RecordType>()
-                                   ->getOriginalDecl()
-                                   ->getDefinitionOrSelf()
-                                   ->fields())
+    for (const FieldDecl *FD : FT->castAsRecordDecl()->fields())
       visit(FD->getType(), FD->getLocation());
   }
   void visitArray(QualType::PrimitiveDefaultInitializeKind PDIK,
@@ -9789,10 +9776,7 @@ struct SearchNonTrivialToCopyField
     S.DiagRuntimeBehavior(SL, E, S.PDiag(diag::note_nontrivial_field) << 0);
   }
   void visitStruct(QualType FT, SourceLocation SL) {
-    for (const FieldDecl *FD : FT->castAs<RecordType>()
-                                   ->getOriginalDecl()
-                                   ->getDefinitionOrSelf()
-                                   ->fields())
+    for (const FieldDecl *FD : FT->castAsRecordDecl()->fields())
       visit(FD->getType(), FD->getLocation());
   }
   void visitArray(QualType::PrimitiveCopyKind PCK, const ArrayType *AT,
@@ -10063,17 +10047,16 @@ void Sema::CheckMemaccessArguments(const CallExpr *Call,
         PDiag(diag::warn_arc_object_memaccess)
           << ArgIdx << FnName << PointeeTy
           << Call->getCallee()->getSourceRange());
-    else if (const auto *RT = PointeeTy->getAs<RecordType>()) {
+    else if (const auto *RD = PointeeTy->getAsRecordDecl()) {
 
       // FIXME: Do not consider incomplete types even though they may be
       // completed later. GCC does not diagnose such code, but we may want to
       // consider diagnosing it in the future, perhaps under a different, but
       // related, diagnostic group.
       bool NonTriviallyCopyableCXXRecord =
-          getLangOpts().CPlusPlus && !RT->isIncompleteType() &&
-          !RT->desugar().isTriviallyCopyableType(Context);
+          getLangOpts().CPlusPlus && RD->isCompleteDefinition() &&
+          !PointeeTy.isTriviallyCopyableType(Context);
 
-      const auto *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
       if ((BId == Builtin::BImemset || BId == Builtin::BIbzero) &&
           RD->isNonTrivialToPrimitiveDefaultInitialize()) {
         DiagRuntimeBehavior(Dest->getExprLoc(), Dest,
@@ -10593,20 +10576,15 @@ struct IntRange {
 
     if (!C.getLangOpts().CPlusPlus) {
       // For enum types in C code, use the underlying datatype.
-      if (const auto *ET = dyn_cast<EnumType>(T))
-        T = ET->getOriginalDecl()
-                ->getDefinitionOrSelf()
-                ->getIntegerType()
-                .getDesugaredType(C)
-                .getTypePtr();
-    } else if (const auto *ET = dyn_cast<EnumType>(T)) {
+      if (const auto *ED = T->getAsEnumDecl())
+        T = ED->getIntegerType().getDesugaredType(C).getTypePtr();
+    } else if (auto *Enum = T->getAsEnumDecl()) {
       // For enum types in C++, use the known bit width of the enumerators.
-      EnumDecl *Enum = ET->getOriginalDecl()->getDefinitionOrSelf();
       // In C++11, enums can have a fixed underlying type. Use this type to
       // compute the range.
       if (Enum->isFixed()) {
         return IntRange(C.getIntWidth(QualType(T, 0)),
-                        !ET->isSignedIntegerOrEnumerationType());
+                        !Enum->getIntegerType()->isSignedIntegerType());
       }
 
       unsigned NumPositive = Enum->getNumPositiveBits();
@@ -10642,10 +10620,8 @@ struct IntRange {
       T = CT->getElementType().getTypePtr();
     if (const AtomicType *AT = dyn_cast<AtomicType>(T))
       T = AT->getValueType().getTypePtr();
-    if (const EnumType *ET = dyn_cast<EnumType>(T))
-      T = C.getCanonicalType(
-               ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType())
-              .getTypePtr();
+    if (const auto *ED = T->getAsEnumDecl())
+      T = C.getCanonicalType(ED->getIntegerType()).getTypePtr();
 
     if (const auto *EIT = dyn_cast<BitIntType>(T))
       return IntRange(EIT->getNumBits(), EIT->isUnsigned());
@@ -11616,10 +11592,7 @@ static bool AnalyzeBitFieldAssignment(Sema &S, FieldDecl *Bitfield, Expr *Init,
   if (BitfieldType->isBooleanType())
      return false;
 
-  if (BitfieldType->isEnumeralType()) {
-    EnumDecl *BitfieldEnumDecl = BitfieldType->castAs<EnumType>()
-                                     ->getOriginalDecl()
-                                     ->getDefinitionOrSelf();
+  if (auto *BitfieldEnumDecl = BitfieldType->getAsEnumDecl()) {
     // If the underlying enum type was not explicitly specified as an unsigned
     // type and the enum contain only positive values, MSVC++ will cause an
     // inconsistency by storing this as a signed type.
@@ -11648,15 +11621,14 @@ static bool AnalyzeBitFieldAssignment(Sema &S, FieldDecl *Bitfield, Expr *Init,
     // The RHS is not constant.  If the RHS has an enum type, make sure the
     // bitfield is wide enough to hold all the values of the enum without
     // truncation.
-    const auto *EnumTy = OriginalInit->getType()->getAs<EnumType>();
+    const auto *ED = OriginalInit->getType()->getAsEnumDecl();
     const PreferredTypeAttr *PTAttr = nullptr;
-    if (!EnumTy) {
+    if (!ED) {
       PTAttr = Bitfield->getAttr<PreferredTypeAttr>();
       if (PTAttr)
-        EnumTy = PTAttr->getType()->getAs<EnumType>();
+        ED = PTAttr->getType()->getAsEnumDecl();
     }
-    if (EnumTy) {
-      EnumDecl *ED = EnumTy->getOriginalDecl()->getDefinitionOrSelf();
+    if (ED) {
       bool SignedBitfield = BitfieldType->isSignedIntegerOrEnumerationType();
 
       // Enum types are implicitly signed on Windows, so check if there are any
@@ -12720,8 +12692,8 @@ void Sema::CheckImplicitConversion(Expr *E, QualType T, SourceLocation CC,
   // type, to give us better diagnostics.
   Source = Context.getCanonicalType(SourceType).getTypePtr();
 
-  if (const EnumType *SourceEnum = Source->getAs<EnumType>())
-    if (const EnumType *TargetEnum = Target->getAs<EnumType>())
+  if (const EnumType *SourceEnum = Source->getAsCanonical<EnumType>())
+    if (const EnumType *TargetEnum = Target->getAsCanonical<EnumType>())
       if (SourceEnum->getOriginalDecl()->hasNameForLinkage() &&
           TargetEnum->getOriginalDecl()->hasNameForLinkage() &&
           SourceEnum != TargetEnum) {
@@ -15361,17 +15333,13 @@ static bool isLayoutCompatible(const ASTContext &C, QualType T1, QualType T2) {
   if (TC1 != TC2)
     return false;
 
-  if (TC1 == Type::Enum) {
-    return isLayoutCompatible(
-        C, cast<EnumType>(T1)->getOriginalDecl()->getDefinitionOrSelf(),
-        cast<EnumType>(T2)->getOriginalDecl()->getDefinitionOrSelf());
-  } else if (TC1 == Type::Record) {
+  if (TC1 == Type::Enum)
+    return isLayoutCompatible(C, T1->castAsEnumDecl(), T2->castAsEnumDecl());
+  if (TC1 == Type::Record) {
     if (!T1->isStandardLayoutType() || !T2->isStandardLayoutType())
       return false;
 
-    return isLayoutCompatible(
-        C, cast<RecordType>(T1)->getOriginalDecl()->getDefinitionOrSelf(),
-        cast<RecordType>(T2)->getOriginalDecl()->getDefinitionOrSelf());
+    return isLayoutCompatible(C, T1->getAsRecordDecl(), T2->getAsRecordDecl());
   }
 
   return false;
@@ -15728,9 +15696,7 @@ void Sema::RefersToMemberWithReducedAlignment(
       return;
     if (ME->isArrow())
       BaseType = BaseType->getPointeeType();
-    RecordDecl *RD = BaseType->castAs<RecordType>()
-                         ->getOriginalDecl()
-                         ->getDefinitionOrSelf();
+    auto *RD = BaseType->castAsRecordDecl();
     if (RD->isInvalidDecl())
       return;
 
diff --git a/clang/lib/Sema/SemaCodeComplete.cpp b/clang/lib/Sema/SemaCodeComplete.cpp
index c6adead5a65f5..03bf4b3690b13 100644
--- a/clang/lib/Sema/SemaCodeComplete.cpp
+++ b/clang/lib/Sema/SemaCodeComplete.cpp
@@ -5113,10 +5113,7 @@ void SemaCodeCompletion::CodeCompleteExpression(
     PreferredTypeIsPointer = Data.PreferredType->isAnyPointerType() ||
                              Data.PreferredType->isMemberPointerType() ||
                              Data.PreferredType->isBlockPointerType();
-    if (Data.PreferredType->isEnumeralType()) {
-      EnumDecl *Enum = Data.PreferredType->castAs<EnumType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf();
+    if (auto *Enum = Data.PreferredType->getAsEnumDecl()) {
       // FIXME: collect covered enumerators in cases like:
       //        if (x == my_enum::one) { ... } else if (x == ^) {}
       AddEnumerators(Results, getASTContext(), Enum, SemaRef.CurContext,
@@ -6240,18 +6237,14 @@ void SemaCodeCompletion::CodeCompleteCase(Scope *S) {
   if (!Switch->getCond())
     return;
   QualType type = Switch->getCond()->IgnoreImplicit()->getType();
-  if (!type->isEnumeralType()) {
+  EnumDecl *Enum = type->getAsEnumDecl();
+  if (!Enum) {
     CodeCompleteExpressionData Data(type);
     Data.IntegralConstantExpression = true;
     CodeCompleteExpression(S, Data);
     return;
   }
 
-  // Code-complete the cases of a switch statement over an enumeration type
-  // by providing the list of
-  EnumDecl *Enum =
-      type->castAs<EnumType>()->getOriginalDecl()->getDefinitionOrSelf();
-
   // Determine which enumerators we have already seen in the switch statement.
   // FIXME: Ideally, we would also be able to look *past* the code-completion
   // token, in case we are code-completing in the middle of the switch and not
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 160d7353cacd9..21a267e244e22 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -9816,8 +9816,7 @@ static void checkIsValidOpenCLKernelParameter(
 
   // At this point we already handled everything except of a RecordType.
   assert(PT->isRecordType() && "Unexpected type.");
-  const RecordDecl *PD =
-      PT->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *PD = PT->castAsRecordDecl();
   VisitStack.push_back(PD);
   assert(VisitStack.back() && "First decl null?");
 
@@ -9845,9 +9844,7 @@ static void checkIsValidOpenCLKernelParameter(
              "Unexpected type.");
       const Type *FieldRecTy = FieldTy->getPointeeOrArrayElementType();
 
-      RD = FieldRecTy->castAs<RecordType>()
-               ->getOriginalDecl()
-               ->getDefinitionOrSelf();
+      RD = FieldRecTy->castAsRecordDecl();
     } else {
       RD = cast<RecordDecl>(Next);
     }
@@ -13375,8 +13372,7 @@ struct DiagNonTrivalCUnionDefaultInitializeVisitor
   }
 
   void visitStruct(QualType QT, const FieldDecl *FD, bool InNonTrivialUnion) {
-    const RecordDecl *RD =
-        QT->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *RD = QT->castAsRecordDecl();
     if (RD->isUnion()) {
       if (OrigLoc.isValid()) {
         bool IsUnion = false;
@@ -13442,8 +13438,7 @@ struct DiagNonTrivalCUnionDestructedTypeVisitor
   }
 
   void visitStruct(QualType QT, const FieldDecl *FD, bool InNonTrivialUnion) {
-    const RecordDecl *RD =
-        QT->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *RD = QT->castAsRecordDecl();
     if (RD->isUnion()) {
       if (OrigLoc.isValid()) {
         bool IsUnion = false;
@@ -13508,8 +13503,7 @@ struct DiagNonTrivalCUnionCopyVisitor
   }
 
   void visitStruct(QualType QT, const FieldDecl *FD, bool InNonTrivialUnion) {
-    const RecordDecl *RD =
-        QT->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+    const auto *RD = QT->castAsRecordDecl();
     if (RD->isUnion()) {
       if (OrigLoc.isValid()) {
         bool IsUnion = false;
@@ -14480,11 +14474,8 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
     //   version of one of these types, or an array of one of the preceding
     //   types and is declared without an initializer.
     if (getLangOpts().CPlusPlus && Var->hasLocalStorage()) {
-      if (const RecordType *Record
-            = Context.getBaseElementType(Type)->getAs<RecordType>()) {
-        CXXRecordDecl *CXXRecord =
-            cast<CXXRecordDecl>(Record->getOriginalDecl())
-                ->getDefinitionOrSelf();
+      if (const auto *CXXRecord =
+              Context.getBaseElementType(Type)->getAsCXXRecordDecl()) {
         // Mark the function (if we're in one) for further checking even if the
         // looser rules of C++11 do not require such checks, so that we can
         // diagnose incompatibilities with C++98.
@@ -14947,8 +14938,8 @@ void Sema::CheckCompleteVariableDeclaration(VarDecl *var) {
 
   // Require the destructor.
   if (!type->isDependentType())
-    if (const RecordType *recordType = baseType->getAs<RecordType>())
-      FinalizeVarWithDestructor(var, recordType);
+    if (auto *RD = baseType->getAsCXXRecordDecl())
+      FinalizeVarWithDestructor(var, RD);
 
   // If this variable must be emitted, add it as an initializer for the current
   // module.
@@ -19119,17 +19110,16 @@ FieldDecl *Sema::CheckFieldDecl(DeclarationName Name, QualType T,
 
   if (!InvalidDecl && getLangOpts().CPlusPlus) {
     if (Record->isUnion()) {
-      if (const RecordType *RT = EltTy->getAs<RecordType>()) {
-        CXXRecordDecl *RDecl = cast<CXXRecordDecl>(RT->getOriginalDecl());
-        if (RDecl->getDefinition()) {
-          // C++ [class.union]p1: An object of a class with a non-trivial
-          // constructor, a non-trivial copy constructor, a non-trivial
-          // destructor, or a non-trivial copy assignment operator
-          // cannot be a member of a union, nor can an array of such
-          // objects.
-          if (CheckNontrivialField(NewFD))
-            NewFD->setInvalidDecl();
-        }
+      if (const auto *RD = EltTy->getAsCXXRecordDecl();
+          RD && (RD->isBeingDefined() || RD->isCompleteDefinition())) {
+
+        // C++ [class.union]p1: An object of a class with a non-trivial
+        // constructor, a non-trivial copy constructor, a non-trivial
+        // destructor, or a non-trivial copy assignment operator
+        // cannot be a member of a union, nor can an array of such
+        // objects.
+        if (CheckNontrivialField(NewFD))
+          NewFD->setInvalidDecl();
       }
 
       // C++ [class.union]p1: If a union contains a member of reference type,
@@ -19185,55 +19175,51 @@ bool Sema::CheckNontrivialField(FieldDecl *FD) {
     return false;
 
   QualType EltTy = Context.getBaseElementType(FD->getType());
-  if (const RecordType *RT = EltTy->getAs<RecordType>()) {
-    CXXRecordDecl *RDecl =
-        cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
-    if (RDecl->getDefinition()) {
-      // We check for copy constructors before constructors
-      // because otherwise we'll never get complaints about
-      // copy constructors.
-
-      CXXSpecialMemberKind member = CXXSpecialMemberKind::Invalid;
-      // We're required to check for any non-trivial constructors. Since the
-      // implicit default constructor is suppressed if there are any
-      // user-declared constructors, we just need to check that there is a
-      // trivial default constructor and a trivial copy constructor. (We don't
-      // worry about move constructors here, since this is a C++98 check.)
-      if (RDecl->hasNonTrivialCopyConstructor())
-        member = CXXSpecialMemberKind::CopyConstructor;
-      else if (!RDecl->hasTrivialDefaultConstructor())
-        member = CXXSpecialMemberKind::DefaultConstructor;
-      else if (RDecl->hasNonTrivialCopyAssignment())
-        member = CXXSpecialMemberKind::CopyAssignment;
-      else if (RDecl->hasNonTrivialDestructor())
-        member = CXXSpecialMemberKind::Destructor;
-
-      if (member != CXXSpecialMemberKind::Invalid) {
-        if (!getLangOpts().CPlusPlus11 &&
-            getLangOpts().ObjCAutoRefCount && RDecl->hasObjectMember()) {
-          // Objective-C++ ARC: it is an error to have a non-trivial field of
-          // a union. However, system headers in Objective-C programs
-          // occasionally have Objective-C lifetime objects within unions,
-          // and rather than cause the program to fail, we make those
-          // members unavailable.
-          SourceLocation Loc = FD->getLocation();
-          if (getSourceManager().isInSystemHeader(Loc)) {
-            if (!FD->hasAttr<UnavailableAttr>())
-              FD->addAttr(UnavailableAttr::CreateImplicit(Context, "",
-                            UnavailableAttr::IR_ARCFieldWithOwnership, Loc));
-            return false;
-          }
+  if (const auto *RDecl = EltTy->getAsCXXRecordDecl();
+      RDecl && (RDecl->isBeingDefined() || RDecl->isCompleteDefinition())) {
+    // We check for copy constructors before constructors
+    // because otherwise we'll never get complaints about
+    // copy constructors.
+
+    CXXSpecialMemberKind member = CXXSpecialMemberKind::Invalid;
+    // We're required to check for any non-trivial constructors. Since the
+    // implicit default constructor is suppressed if there are any
+    // user-declared constructors, we just need to check that there is a
+    // trivial default constructor and a trivial copy constructor. (We don't
+    // worry about move constructors here, since this is a C++98 check.)
+    if (RDecl->hasNonTrivialCopyConstructor())
+      member = CXXSpecialMemberKind::CopyConstructor;
+    else if (!RDecl->hasTrivialDefaultConstructor())
+      member = CXXSpecialMemberKind::DefaultConstructor;
+    else if (RDecl->hasNonTrivialCopyAssignment())
+      member = CXXSpecialMemberKind::CopyAssignment;
+    else if (RDecl->hasNonTrivialDestructor())
+      member = CXXSpecialMemberKind::Destructor;
+
+    if (member != CXXSpecialMemberKind::Invalid) {
+      if (!getLangOpts().CPlusPlus11 && getLangOpts().ObjCAutoRefCount &&
+          RDecl->hasObjectMember()) {
+        // Objective-C++ ARC: it is an error to have a non-trivial field of
+        // a union. However, system headers in Objective-C programs
+        // occasionally have Objective-C lifetime objects within unions,
+        // and rather than cause the program to fail, we make those
+        // members unavailable.
+        SourceLocation Loc = FD->getLocation();
+        if (getSourceManager().isInSystemHeader(Loc)) {
+          if (!FD->hasAttr<UnavailableAttr>())
+            FD->addAttr(UnavailableAttr::CreateImplicit(
+                Context, "", UnavailableAttr::IR_ARCFieldWithOwnership, Loc));
+          return false;
         }
-
-        Diag(
-            FD->getLocation(),
-            getLangOpts().CPlusPlus11
-                ? diag::warn_cxx98_compat_nontrivial_union_or_anon_struct_member
-                : diag::err_illegal_union_or_anon_struct_member)
-            << FD->getParent()->isUnion() << FD->getDeclName() << member;
-        DiagnoseNontrivial(RDecl, member);
-        return !getLangOpts().CPlusPlus11;
       }
+
+      Diag(FD->getLocation(),
+           getLangOpts().CPlusPlus11
+               ? diag::warn_cxx98_compat_nontrivial_union_or_anon_struct_member
+               : diag::err_illegal_union_or_anon_struct_member)
+          << FD->getParent()->isUnion() << FD->getDeclName() << member;
+      DiagnoseNontrivial(RDecl, member);
+      return !getLangOpts().CPlusPlus11;
     }
   }
 
@@ -19509,11 +19495,9 @@ bool Sema::EntirelyFunctionPointers(const RecordDecl *Record) {
       return PointeeType.getDesugaredType(Context)->isFunctionType();
     }
     // If a member is a struct entirely of function pointers, that counts too.
-    if (const RecordType *RT = FieldType->getAs<RecordType>()) {
-      const RecordDecl *Record = RT->getOriginalDecl()->getDefinitionOrSelf();
-      if (Record->isStruct() && EntirelyFunctionPointers(Record))
-        return true;
-    }
+    if (const auto *Record = FieldType->getAsRecordDecl();
+        Record && Record->isStruct() && EntirelyFunctionPointers(Record))
+      return true;
     return false;
   };
 
@@ -19667,10 +19651,8 @@ void Sema::ActOnFields(Scope *S, SourceLocation RecLoc, Decl *EnclosingDecl,
       FD->setInvalidDecl();
       EnclosingDecl->setInvalidDecl();
       continue;
-    } else if (const RecordType *FDTTy = FDTy->getAs<RecordType>()) {
-      if (Record && FDTTy->getOriginalDecl()
-                        ->getDefinitionOrSelf()
-                        ->hasFlexibleArrayMember()) {
+    } else if (const auto *RD = FDTy->getAsRecordDecl()) {
+      if (Record && RD->hasFlexibleArrayMember()) {
         // A type which contains a flexible array member is considered to be a
         // flexible array member.
         Record->setHasFlexibleArrayMember(true);
@@ -19696,7 +19678,6 @@ void Sema::ActOnFields(Scope *S, SourceLocation RecLoc, Decl *EnclosingDecl,
         // Ivars can not have abstract class types
         FD->setInvalidDecl();
       }
-      const RecordDecl *RD = FDTTy->getOriginalDecl()->getDefinitionOrSelf();
       if (Record && RD->hasObjectMember())
         Record->setHasObjectMember(true);
       if (Record && RD->hasVolatileMember())
@@ -19730,10 +19711,8 @@ void Sema::ActOnFields(Scope *S, SourceLocation RecLoc, Decl *EnclosingDecl,
         Record->setHasObjectMember(true);
       else if (Context.getAsArrayType(FD->getType())) {
         QualType BaseType = Context.getBaseElementType(FD->getType());
-        if (BaseType->isRecordType() && BaseType->castAs<RecordType>()
-                                            ->getOriginalDecl()
-                                            ->getDefinitionOrSelf()
-                                            ->hasObjectMember())
+        if (const auto *RD = BaseType->getAsRecordDecl();
+            RD && RD->hasObjectMember())
           Record->setHasObjectMember(true);
         else if (BaseType->isObjCObjectPointerType() ||
                  BaseType.isObjCGCStrong())
@@ -19765,10 +19744,8 @@ void Sema::ActOnFields(Scope *S, SourceLocation RecLoc, Decl *EnclosingDecl,
           Record->setHasNonTrivialToPrimitiveDestructCUnion(true);
       }
 
-      if (const auto *RT = FT->getAs<RecordType>()) {
-        if (RT->getOriginalDecl()
-                ->getDefinitionOrSelf()
-                ->getArgPassingRestrictions() ==
+      if (const auto *RD = FT->getAsRecordDecl()) {
+        if (RD->getArgPassingRestrictions() ==
             RecordArgPassingKind::CanNeverPassInRegs)
           Record->setArgPassingRestrictions(
               RecordArgPassingKind::CanNeverPassInRegs);
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 3685bcefbf764..3ded60cd8b073 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -169,7 +169,7 @@ static bool isIntOrBool(Expr *Exp) {
 
 // Check to see if the type is a smart pointer of some kind.  We assume
 // it's a smart pointer if it defines both operator-> and operator*.
-static bool threadSafetyCheckIsSmartPointer(Sema &S, const RecordType* RT) {
+static bool threadSafetyCheckIsSmartPointer(Sema &S, const RecordDecl *Record) {
   auto IsOverloadedOperatorPresent = [&S](const RecordDecl *Record,
                                           OverloadedOperatorKind Op) {
     DeclContextLookupResult Result =
@@ -177,7 +177,6 @@ static bool threadSafetyCheckIsSmartPointer(Sema &S, const RecordType* RT) {
     return !Result.empty();
   };
 
-  const RecordDecl *Record = RT->getOriginalDecl()->getDefinitionOrSelf();
   bool foundStarOperator = IsOverloadedOperatorPresent(Record, OO_Star);
   bool foundArrowOperator = IsOverloadedOperatorPresent(Record, OO_Arrow);
   if (foundStarOperator && foundArrowOperator)
@@ -212,14 +211,14 @@ static bool threadSafetyCheckIsPointer(Sema &S, const Decl *D,
   if (QT->isAnyPointerType())
     return true;
 
-  if (const auto *RT = QT->getAs<RecordType>()) {
+  if (const auto *RD = QT->getAsRecordDecl()) {
     // If it's an incomplete type, it could be a smart pointer; skip it.
     // (We don't want to force template instantiation if we can avoid it,
     // since that would alter the order in which templates are instantiated.)
-    if (RT->isIncompleteType())
+    if (!RD->isCompleteDefinition())
       return true;
 
-    if (threadSafetyCheckIsSmartPointer(S, RT))
+    if (threadSafetyCheckIsSmartPointer(S, RD))
       return true;
   }
 
@@ -229,13 +228,13 @@ static bool threadSafetyCheckIsPointer(Sema &S, const Decl *D,
 
 /// Checks that the passed in QualType either is of RecordType or points
 /// to RecordType. Returns the relevant RecordType, null if it does not exit.
-static const RecordType *getRecordType(QualType QT) {
-  if (const auto *RT = QT->getAs<RecordType>())
-    return RT;
+static const RecordDecl *getRecordDecl(QualType QT) {
+  if (const auto *RD = QT->getAsRecordDecl())
+    return RD;
 
-  // Now check if we point to record type.
-  if (const auto *PT = QT->getAs<PointerType>())
-    return PT->getPointeeType()->getAs<RecordType>();
+  // Now check if we point to a record.
+  if (const auto *PT = QT->getAsCanonical<PointerType>())
+    return PT->getPointeeType()->getAsRecordDecl();
 
   return nullptr;
 }
@@ -257,36 +256,34 @@ static bool checkRecordDeclForAttr(const RecordDecl *RD) {
 }
 
 static bool checkRecordTypeForCapability(Sema &S, QualType Ty) {
-  const RecordType *RT = getRecordType(Ty);
+  const auto *RD = getRecordDecl(Ty);
 
-  if (!RT)
+  if (!RD)
     return false;
 
   // Don't check for the capability if the class hasn't been defined yet.
-  if (RT->isIncompleteType())
+  if (!RD->isCompleteDefinition())
     return true;
 
   // Allow smart pointers to be used as capability objects.
   // FIXME -- Check the type that the smart pointer points to.
-  if (threadSafetyCheckIsSmartPointer(S, RT))
+  if (threadSafetyCheckIsSmartPointer(S, RD))
     return true;
 
-  return checkRecordDeclForAttr<CapabilityAttr>(
-      RT->getOriginalDecl()->getDefinitionOrSelf());
+  return checkRecordDeclForAttr<CapabilityAttr>(RD);
 }
 
 static bool checkRecordTypeForScopedCapability(Sema &S, QualType Ty) {
-  const RecordType *RT = getRecordType(Ty);
+  const auto *RD = getRecordDecl(Ty);
 
-  if (!RT)
+  if (!RD)
     return false;
 
   // Don't check for the capability if the class hasn't been defined yet.
-  if (RT->isIncompleteType())
+  if (!RD->isCompleteDefinition())
     return true;
 
-  return checkRecordDeclForAttr<ScopedLockableAttr>(
-      RT->getOriginalDecl()->getDefinitionOrSelf());
+  return checkRecordDeclForAttr<ScopedLockableAttr>(RD);
 }
 
 static bool checkTypedefTypeForCapability(QualType Ty) {
@@ -401,10 +398,10 @@ static void checkAttrArgsAreCapabilityObjs(Sema &S, Decl *D,
             ArgTy = DRE->getDecl()->getType();
 
     // First see if we can just cast to record type, or pointer to record type.
-    const RecordType *RT = getRecordType(ArgTy);
+    const auto *RD = getRecordDecl(ArgTy);
 
     // Now check if we index into a record type function param.
-    if(!RT && ParamIdxOk) {
+    if (!RD && ParamIdxOk) {
       const auto *FD = dyn_cast<FunctionDecl>(D);
       const auto *IL = dyn_cast<IntegerLiteral>(ArgExp);
       if(FD && IL) {
@@ -3639,7 +3636,7 @@ static void handleInitPriorityAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   QualType T = cast<VarDecl>(D)->getType();
   if (S.Context.getAsArrayType(T))
     T = S.Context.getBaseElementType(T);
-  if (!T->getAs<RecordType>()) {
+  if (!T->isRecordType()) {
     S.Diag(AL.getLoc(), diag::err_init_priority_object_attr);
     AL.setInvalid();
     return;
@@ -4163,10 +4160,7 @@ static void handleTransparentUnionAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   RecordDecl *RD = nullptr;
   const auto *TD = dyn_cast<TypedefNameDecl>(D);
   if (TD && TD->getUnderlyingType()->isUnionType())
-    RD = TD->getUnderlyingType()
-             ->getAsUnionType()
-             ->getOriginalDecl()
-             ->getDefinitionOrSelf();
+    RD = TD->getUnderlyingType()->getAsRecordDecl();
   else
     RD = dyn_cast<RecordDecl>(D);
 
@@ -4739,14 +4733,14 @@ void Sema::AddModeAttr(Decl *D, const AttributeCommonInfo &CI,
   // GCC allows 'mode' attribute on enumeration types (even incomplete), except
   // for vector modes. So, 'enum X __attribute__((mode(QI)));' forms a complete
   // type, 'enum { A } __attribute__((mode(V4SI)))' is rejected.
-  if ((isa<EnumDecl>(D) || OldElemTy->getAs<EnumType>()) &&
+  if ((isa<EnumDecl>(D) || OldElemTy->isEnumeralType()) &&
       VectorSize.getBoolValue()) {
     Diag(AttrLoc, diag::err_enum_mode_vector_type) << Name << CI.getRange();
     return;
   }
   bool IntegralOrAnyEnumType = (OldElemTy->isIntegralOrEnumerationType() &&
                                 !OldElemTy->isBitIntType()) ||
-                               OldElemTy->getAs<EnumType>();
+                               OldElemTy->isEnumeralType();
 
   if (!OldElemTy->getAs<BuiltinType>() && !OldElemTy->isComplexType() &&
       !IntegralOrAnyEnumType)
diff --git a/clang/lib/Sema/SemaDeclCXX.cpp b/clang/lib/Sema/SemaDeclCXX.cpp
index bb57968830a65..4b0dead182543 100644
--- a/clang/lib/Sema/SemaDeclCXX.cpp
+++ b/clang/lib/Sema/SemaDeclCXX.cpp
@@ -2187,10 +2187,7 @@ static bool CheckConstexprCtorInitializer(Sema &SemaRef,
       return false;
     }
   } else if (Field->isAnonymousStructOrUnion()) {
-    const RecordDecl *RD = Field->getType()
-                               ->castAs<RecordType>()
-                               ->getOriginalDecl()
-                               ->getDefinitionOrSelf();
+    const auto *RD = Field->getType()->castAsRecordDecl();
     for (auto *I : RD->fields())
       // If an anonymous union contains an anonymous struct of which any member
       // is initialized, all members must be initialized.
@@ -2925,9 +2922,7 @@ NoteIndirectBases(ASTContext &Context, IndirectBaseSet &Set,
 {
   // Even though the incoming type is a base, it might not be
   // a class -- it could be a template parm, for instance.
-  if (auto Rec = Type->getAs<RecordType>()) {
-    auto Decl = Rec->getAsCXXRecordDecl();
-
+  if (const auto *Decl = Type->getAsCXXRecordDecl()) {
     // Iterate over its bases.
     for (const auto &BaseSpec : Decl->bases()) {
       QualType Base = Context.getCanonicalType(BaseSpec.getType())
@@ -2986,9 +2981,7 @@ bool Sema::AttachBaseSpecifiers(CXXRecordDecl *Class,
       if (Bases.size() > 1)
         NoteIndirectBases(Context, IndirectBaseTypes, NewBaseType);
 
-      if (const RecordType *Record = NewBaseType->getAs<RecordType>()) {
-        const CXXRecordDecl *RD = cast<CXXRecordDecl>(Record->getOriginalDecl())
-                                      ->getDefinitionOrSelf();
+      if (const auto *RD = NewBaseType->getAsCXXRecordDecl()) {
         if (Class->isInterface() &&
               (!RD->isInterfaceLike() ||
                KnownBase->getAccessSpecifier() != AS_public)) {
@@ -5479,7 +5472,8 @@ bool Sema::SetCtorInitializers(CXXConstructorDecl *Constructor, bool AnyErrors,
     CXXCtorInitializer *Member = Initializers[i];
 
     if (Member->isBaseInitializer())
-      Info.AllBaseFields[Member->getBaseClass()->getAs<RecordType>()] = Member;
+      Info.AllBaseFields[Member->getBaseClass()->getAsCanonical<RecordType>()] =
+          Member;
     else {
       Info.AllBaseFields[Member->getAnyMember()->getCanonicalDecl()] = Member;
 
@@ -5507,8 +5501,8 @@ bool Sema::SetCtorInitializers(CXXConstructorDecl *Constructor, bool AnyErrors,
 
   // Push virtual bases before others.
   for (auto &VBase : ClassDecl->vbases()) {
-    if (CXXCtorInitializer *Value
-        = Info.AllBaseFields.lookup(VBase.getType()->getAs<RecordType>())) {
+    if (CXXCtorInitializer *Value = Info.AllBaseFields.lookup(
+            VBase.getType()->getAsCanonical<RecordType>())) {
       // [class.base.init]p7, per DR257:
       //   A mem-initializer where the mem-initializer-id names a virtual base
       //   class is ignored during execution of a constructor of any class that
@@ -5546,8 +5540,8 @@ bool Sema::SetCtorInitializers(CXXConstructorDecl *Constructor, bool AnyErrors,
     if (Base.isVirtual())
       continue;
 
-    if (CXXCtorInitializer *Value
-          = Info.AllBaseFields.lookup(Base.getType()->getAs<RecordType>())) {
+    if (CXXCtorInitializer *Value = Info.AllBaseFields.lookup(
+            Base.getType()->getAsCanonical<RecordType>())) {
       Info.AllToInit.push_back(Value);
     } else if (!AnyErrors) {
       CXXCtorInitializer *CXXBaseInit;
@@ -5635,7 +5629,7 @@ bool Sema::SetCtorInitializers(CXXConstructorDecl *Constructor, bool AnyErrors,
 }
 
 static void PopulateKeysForFields(FieldDecl *Field, SmallVectorImpl<const void*> &IdealInits) {
-  if (const RecordType *RT = Field->getType()->getAs<RecordType>()) {
+  if (const RecordType *RT = Field->getType()->getAsCanonical<RecordType>()) {
     const RecordDecl *RD = RT->getOriginalDecl();
     if (RD->isAnonymousStructOrUnion()) {
       for (auto *Field : RD->getDefinitionOrSelf()->fields())
@@ -7611,12 +7605,9 @@ static bool defaultedSpecialMemberIsConstexpr(
   //      class is a constexpr function, and
   if (!S.getLangOpts().CPlusPlus23) {
     for (const auto &B : ClassDecl->bases()) {
-      const RecordType *BaseType = B.getType()->getAs<RecordType>();
-      if (!BaseType)
+      auto *BaseClassDecl = B.getType()->getAsCXXRecordDecl();
+      if (!BaseClassDecl)
         continue;
-      CXXRecordDecl *BaseClassDecl =
-          cast<CXXRecordDecl>(BaseType->getOriginalDecl())
-              ->getDefinitionOrSelf();
       if (!specialMemberIsConstexpr(S, BaseClassDecl, CSM, 0, ConstArg,
                                     InheritedCtor, Inherited))
         return false;
@@ -7638,7 +7629,7 @@ static bool defaultedSpecialMemberIsConstexpr(
           F->hasInClassInitializer())
         continue;
       QualType BaseType = S.Context.getBaseElementType(F->getType());
-      if (const RecordType *RecordTy = BaseType->getAs<RecordType>()) {
+      if (const RecordType *RecordTy = BaseType->getAsCanonical<RecordType>()) {
         CXXRecordDecl *FieldRecDecl =
             cast<CXXRecordDecl>(RecordTy->getOriginalDecl())
                 ->getDefinitionOrSelf();
@@ -10478,11 +10469,7 @@ struct FindHiddenVirtualMethod {
   /// method overloads virtual methods in a base class without overriding any,
   /// to be used with CXXRecordDecl::lookupInBases().
   bool operator()(const CXXBaseSpecifier *Specifier, CXXBasePath &Path) {
-    RecordDecl *BaseRecord = Specifier->getType()
-                                 ->castAs<RecordType>()
-                                 ->getOriginalDecl()
-                                 ->getDefinitionOrSelf();
-
+    auto *BaseRecord = Specifier->getType()->castAsRecordDecl();
     DeclarationName Name = Method->getDeclName();
     assert(Name.getNameKind() == DeclarationName::Identifier);
 
@@ -10637,7 +10624,8 @@ void Sema::checkIllFormedTrivialABIStruct(CXXRecordDecl &RD) {
       return;
     }
 
-    if (const auto *RT = FT->getBaseElementTypeUnsafe()->getAs<RecordType>())
+    if (const auto *RT =
+            FT->getBaseElementTypeUnsafe()->getAsCanonical<RecordType>())
       if (!RT->isDependentType() &&
           !cast<CXXRecordDecl>(RT->getOriginalDecl()->getDefinitionOrSelf())
                ->canPassInRegisters()) {
@@ -12658,15 +12646,12 @@ Decl *Sema::ActOnUsingEnumDeclaration(Scope *S, AccessSpecifier AS,
     return nullptr;
   }
 
-  auto *Enum = dyn_cast_if_present<EnumDecl>(EnumTy->getAsTagDecl());
+  auto *Enum = EnumTy->getAsEnumDecl();
   if (!Enum) {
     Diag(IdentLoc, diag::err_using_enum_not_enum) << EnumTy;
     return nullptr;
   }
 
-  if (auto *Def = Enum->getDefinition())
-    Enum = Def;
-
   if (TSI == nullptr)
     TSI = Context.getTrivialTypeSourceInfo(EnumTy, IdentLoc);
 
@@ -13937,12 +13922,10 @@ struct SpecialMemberExceptionSpecInfo
 }
 
 bool SpecialMemberExceptionSpecInfo::visitBase(CXXBaseSpecifier *Base) {
-  auto *RT = Base->getType()->getAs<RecordType>();
-  if (!RT)
+  auto *BaseClass = Base->getType()->getAsCXXRecordDecl();
+  if (!BaseClass)
     return false;
 
-  auto *BaseClass =
-      cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
   Sema::SpecialMemberOverloadResult SMOR = lookupInheritedCtor(BaseClass);
   if (auto *BaseCtor = SMOR.getMethod()) {
     visitSubobjectCall(Base, BaseCtor);
@@ -13966,11 +13949,9 @@ bool SpecialMemberExceptionSpecInfo::visitField(FieldDecl *FD) {
       E = S.BuildCXXDefaultInitExpr(Loc, FD).get();
     if (E)
       ExceptSpec.CalledExpr(E);
-  } else if (auto *RT = S.Context.getBaseElementType(FD->getType())
-                            ->getAs<RecordType>()) {
-    visitClassSubobject(
-        cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf(), FD,
-        FD->getType().getCVRQualifiers());
+  } else if (auto *RD = S.Context.getBaseElementType(FD->getType())
+                            ->getAsCXXRecordDecl()) {
+    visitClassSubobject(RD, FD, FD->getType().getCVRQualifiers());
   }
   return false;
 }
@@ -14785,11 +14766,9 @@ buildMemcpyForAssignmentOp(Sema &S, SourceLocation Loc, QualType T,
       S.Context, To, UO_AddrOf, S.Context.getPointerType(To->getType()),
       VK_PRValue, OK_Ordinary, Loc, false, S.CurFPFeatureOverrides());
 
-  const Type *E = T->getBaseElementTypeUnsafe();
-  bool NeedsCollectableMemCpy = E->isRecordType() && E->castAs<RecordType>()
-                                                         ->getOriginalDecl()
-                                                         ->getDefinitionOrSelf()
-                                                         ->hasObjectMember();
+  bool NeedsCollectableMemCpy = false;
+  if (auto *RD = T->getBaseElementTypeUnsafe()->getAsRecordDecl())
+    NeedsCollectableMemCpy = RD->hasObjectMember();
 
   // Create a reference to the __builtin_objc_memmove_collectable function
   StringRef MemCpyName = NeedsCollectableMemCpy ?
@@ -14865,10 +14844,7 @@ buildSingleCopyAssignRecursively(Sema &S, SourceLocation Loc, QualType T,
   //       the class is used (as if by explicit qualification; that is,
   //       ignoring any possible virtual overriding functions in more derived
   //       classes);
-  if (const RecordType *RecordTy = T->getAs<RecordType>()) {
-    CXXRecordDecl *ClassDecl =
-        cast<CXXRecordDecl>(RecordTy->getOriginalDecl())->getDefinitionOrSelf();
-
+  if (auto *ClassDecl = T->getAsCXXRecordDecl()) {
     // Look for operator=.
     DeclarationName Name
       = S.Context.DeclarationNames.getCXXOperatorName(OO_Equal);
@@ -15340,7 +15316,7 @@ void Sema::DefineImplicitCopyAssignment(SourceLocation CurrentLocation,
 
     // Check for members of const-qualified, non-class type.
     QualType BaseType = Context.getBaseElementType(Field->getType());
-    if (!BaseType->getAs<RecordType>() && BaseType.isConstQualified()) {
+    if (!BaseType->isRecordType() && BaseType.isConstQualified()) {
       Diag(ClassDecl->getLocation(), diag::err_uninitialized_member_for_assign)
           << Context.getCanonicalTagType(ClassDecl) << 1
           << Field->getDeclName();
@@ -15729,7 +15705,7 @@ void Sema::DefineImplicitMoveAssignment(SourceLocation CurrentLocation,
 
     // Check for members of const-qualified, non-class type.
     QualType BaseType = Context.getBaseElementType(Field->getType());
-    if (!BaseType->getAs<RecordType>() && BaseType.isConstQualified()) {
+    if (!BaseType->isRecordType() && BaseType.isConstQualified()) {
       Diag(ClassDecl->getLocation(), diag::err_uninitialized_member_for_assign)
           << Context.getCanonicalTagType(ClassDecl) << 1
           << Field->getDeclName();
@@ -16322,15 +16298,14 @@ ExprResult Sema::BuildCXXConstructExpr(
       Constructor);
 }
 
-void Sema::FinalizeVarWithDestructor(VarDecl *VD, const RecordType *Record) {
+void Sema::FinalizeVarWithDestructor(VarDecl *VD, CXXRecordDecl *ClassDecl) {
   if (VD->isInvalidDecl()) return;
   // If initializing the variable failed, don't also diagnose problems with
   // the destructor, they're likely related.
   if (VD->getInit() && VD->getInit()->containsErrors())
     return;
 
-  CXXRecordDecl *ClassDecl =
-      cast<CXXRecordDecl>(Record->getOriginalDecl())->getDefinitionOrSelf();
+  ClassDecl = ClassDecl->getDefinitionOrSelf();
   if (ClassDecl->isInvalidDecl()) return;
   if (ClassDecl->hasIrrelevantDestructor()) return;
   if (ClassDecl->isDependentContext()) return;
@@ -16988,9 +16963,9 @@ checkLiteralOperatorTemplateParameterList(Sema &SemaRef,
     // first template parameter as its type.
     if (PmType && PmArgs && !PmType->isTemplateParameterPack() &&
         PmArgs->isTemplateParameterPack()) {
-      const TemplateTypeParmType *TArgs =
-          PmArgs->getType()->getAs<TemplateTypeParmType>();
-      if (TArgs && TArgs->getDepth() == PmType->getDepth() &&
+      if (const auto *TArgs =
+              PmArgs->getType()->getAsCanonical<TemplateTypeParmType>();
+          TArgs && TArgs->getDepth() == PmType->getDepth() &&
           TArgs->getIndex() == PmType->getIndex()) {
         if (!SemaRef.inTemplateInstantiation())
           SemaRef.Diag(TpDecl->getLocation(),
@@ -17337,7 +17312,7 @@ VarDecl *Sema::BuildExceptionDeclaration(Scope *S, TypeSourceInfo *TInfo,
     Invalid = true;
 
   if (!Invalid && !ExDeclType->isDependentType()) {
-    if (const RecordType *recordType = ExDeclType->getAs<RecordType>()) {
+    if (auto *ClassDecl = ExDeclType->getAsCXXRecordDecl()) {
       // Insulate this from anything else we might currently be parsing.
       EnterExpressionEvaluationContext scope(
           *this, ExpressionEvaluationContext::PotentiallyEvaluated);
@@ -17374,7 +17349,7 @@ VarDecl *Sema::BuildExceptionDeclaration(Scope *S, TypeSourceInfo *TInfo,
         }
 
         // And make sure it's destructable.
-        FinalizeVarWithDestructor(ExDecl, recordType);
+        FinalizeVarWithDestructor(ExDecl, ClassDecl);
       }
     }
   }
@@ -18813,8 +18788,8 @@ bool Sema::CheckOverridingFunctionReturnType(const CXXMethodDecl *New,
     //   that of B::f, the class type in the return type of D::f shall be
     //   complete at the point of declaration of D::f or shall be the class
     //   type D.
-    if (const RecordType *RT = NewClassTy->getAs<RecordType>()) {
-      if (!RT->getOriginalDecl()->isEntityBeingDefined() &&
+    if (const auto *RD = NewClassTy->getAsCXXRecordDecl()) {
+      if (!RD->isBeingDefined() &&
           RequireCompleteType(New->getLocation(), NewClassTy,
                               diag::err_covariant_return_incomplete,
                               New->getDeclName()))
@@ -19169,9 +19144,7 @@ void Sema::MarkVirtualMembersReferenced(SourceLocation Loc,
     return;
 
   for (const auto &I : RD->bases()) {
-    const auto *Base = cast<CXXRecordDecl>(
-                           I.getType()->castAs<RecordType>()->getOriginalDecl())
-                           ->getDefinitionOrSelf();
+    const auto *Base = I.getType()->castAsCXXRecordDecl();
     if (Base->getNumVBases() == 0)
       continue;
     MarkVirtualMembersReferenced(Loc, Base);
diff --git a/clang/lib/Sema/SemaDeclObjC.cpp b/clang/lib/Sema/SemaDeclObjC.cpp
index 88ed83eca243e..98eb5afb7c992 100644
--- a/clang/lib/Sema/SemaDeclObjC.cpp
+++ b/clang/lib/Sema/SemaDeclObjC.cpp
@@ -3848,10 +3848,8 @@ SemaObjC::ObjCContainerKind SemaObjC::getObjCContainerKind() const {
 static bool IsVariableSizedType(QualType T) {
   if (T->isIncompleteArrayType())
     return true;
-  const auto *RecordTy = T->getAs<RecordType>();
-  return (RecordTy && RecordTy->getOriginalDecl()
-                          ->getDefinitionOrSelf()
-                          ->hasFlexibleArrayMember());
+  const auto *RD = T->getAsRecordDecl();
+  return RD && RD->hasFlexibleArrayMember();
 }
 
 static void DiagnoseVariableSizedIvars(Sema &S, ObjCContainerDecl *OCD) {
@@ -3896,15 +3894,11 @@ static void DiagnoseVariableSizedIvars(Sema &S, ObjCContainerDecl *OCD) {
           << ivar->getDeclName() << IvarTy
           << TagTypeKind::Class; // Use "class" for Obj-C.
       IsInvalidIvar = true;
-    } else if (const RecordType *RecordTy = IvarTy->getAs<RecordType>()) {
-      if (RecordTy->getOriginalDecl()
-              ->getDefinitionOrSelf()
-              ->hasFlexibleArrayMember()) {
-        S.Diag(ivar->getLocation(),
-               diag::err_objc_variable_sized_type_not_at_end)
-            << ivar->getDeclName() << IvarTy;
-        IsInvalidIvar = true;
-      }
+    } else if (const auto *RD = IvarTy->getAsRecordDecl();
+               RD && RD->hasFlexibleArrayMember()) {
+      S.Diag(ivar->getLocation(), diag::err_objc_variable_sized_type_not_at_end)
+          << ivar->getDeclName() << IvarTy;
+      IsInvalidIvar = true;
     }
     if (IsInvalidIvar) {
       S.Diag(ivar->getNextIvar()->getLocation(),
@@ -5541,11 +5535,8 @@ void SemaObjC::SetIvarInitializers(ObjCImplementationDecl *ObjCImplementation) {
       AllToInit.push_back(Member);
 
       // Be sure that the destructor is accessible and is marked as referenced.
-      if (const RecordType *RecordTy =
-              Context.getBaseElementType(Field->getType())
-                  ->getAs<RecordType>()) {
-        CXXRecordDecl *RD = cast<CXXRecordDecl>(RecordTy->getOriginalDecl())
-                                ->getDefinitionOrSelf();
+      if (auto *RD = Context.getBaseElementType(Field->getType())
+                         ->getAsCXXRecordDecl()) {
         if (CXXDestructorDecl *Destructor = SemaRef.LookupDestructor(RD)) {
           SemaRef.MarkFunctionReferenced(Field->getLocation(), Destructor);
           SemaRef.CheckDestructorAccess(
diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp
index 94413b5b92d22..552c92996dc2e 100644
--- a/clang/lib/Sema/SemaExceptionSpec.cpp
+++ b/clang/lib/Sema/SemaExceptionSpec.cpp
@@ -163,9 +163,8 @@ bool Sema::CheckSpecifiedExceptionType(QualType &T, SourceRange Range) {
     DiagID = diag::ext_incomplete_in_exception_spec;
     ReturnValueOnError = false;
   }
-  if (!(PointeeT->isRecordType() && PointeeT->castAs<RecordType>()
-                                        ->getOriginalDecl()
-                                        ->isEntityBeingDefined()) &&
+  if (auto *RD = PointeeT->getAsRecordDecl();
+      !(RD && RD->isBeingDefined()) &&
       RequireCompleteType(Range.getBegin(), PointeeT, DiagID, Kind, Range))
     return ReturnValueOnError;
 
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 911c80adfd539..b26fa29257696 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -1528,8 +1528,12 @@ void Sema::checkEnumArithmeticConversions(Expr *LHS, Expr *RHS,
     // are ill-formed.
     if (getLangOpts().CPlusPlus26)
       DiagID = diag::warn_conv_mixed_enum_types_cxx26;
-    else if (!L->castAs<EnumType>()->getOriginalDecl()->hasNameForLinkage() ||
-             !R->castAs<EnumType>()->getOriginalDecl()->hasNameForLinkage()) {
+    else if (!L->castAsCanonical<EnumType>()
+                  ->getOriginalDecl()
+                  ->hasNameForLinkage() ||
+             !R->castAsCanonical<EnumType>()
+                  ->getOriginalDecl()
+                  ->hasNameForLinkage()) {
       // If either enumeration type is unnamed, it's less likely that the
       // user cares about this, but this situation is still deprecated in
       // C++2a. Use a different warning group.
@@ -8605,16 +8609,11 @@ QualType Sema::CheckConditionalOperands(ExprResult &Cond, ExprResult &LHS,
 
   // If both operands are the same structure or union type, the result is that
   // type.
-  if (const RecordType *LHSRT = LHSTy->getAs<RecordType>()) {    // C99 6.5.15p3
-    if (const RecordType *RHSRT = RHSTy->getAs<RecordType>())
-      if (declaresSameEntity(LHSRT->getOriginalDecl(),
-                             RHSRT->getOriginalDecl()))
-        // "If both the operands have structure or union type, the result has
-        // that type."  This implies that CV qualifiers are dropped.
-        return Context.getCommonSugaredType(LHSTy.getUnqualifiedType(),
-                                            RHSTy.getUnqualifiedType());
-    // FIXME: Type of conditional expression must be complete in C mode.
-  }
+  // FIXME: Type of conditional expression must be complete in C mode.
+  if (LHSTy->isRecordType() &&
+      Context.hasSameUnqualifiedType(LHSTy, RHSTy)) // C99 6.5.15p3
+    return Context.getCommonSugaredType(LHSTy.getUnqualifiedType(),
+                                        RHSTy.getUnqualifiedType());
 
   // C99 6.5.15p5: "If both operands have void type, the result has void type."
   // The following || allows only one side to be void (a GCC-ism).
@@ -11406,7 +11405,7 @@ QualType Sema::CheckSubtractionOperands(ExprResult &LHS, ExprResult &RHS,
 }
 
 static bool isScopedEnumerationType(QualType T) {
-  if (const EnumType *ET = T->getAs<EnumType>())
+  if (const EnumType *ET = T->getAsCanonical<EnumType>())
     return ET->getOriginalDecl()->isScoped();
   return false;
 }
@@ -12296,10 +12295,7 @@ static QualType checkArithmeticOrEnumeralThreeWayCompare(Sema &S,
       S.InvalidOperands(Loc, LHS, RHS);
       return QualType();
     }
-    QualType IntType = LHSStrippedType->castAs<EnumType>()
-                           ->getOriginalDecl()
-                           ->getDefinitionOrSelf()
-                           ->getIntegerType();
+    QualType IntType = LHSStrippedType->castAsEnumDecl()->getIntegerType();
     assert(IntType->isArithmeticType());
 
     // We can't use `CK_IntegralCast` when the underlying type is 'bool', so we
@@ -13722,7 +13718,7 @@ static void DiagnoseRecursiveConstFields(Sema &S, const ValueDecl *VD,
 
       // Then we append it to the list to check next in order.
       FieldTy = FieldTy.getCanonicalType();
-      if (const auto *FieldRecTy = FieldTy->getAs<RecordType>()) {
+      if (const auto *FieldRecTy = FieldTy->getAsCanonical<RecordType>()) {
         if (!llvm::is_contained(RecordTypeList, FieldRecTy))
           RecordTypeList.push_back(FieldRecTy);
       }
@@ -13738,7 +13734,7 @@ static void DiagnoseRecursiveConstFields(Sema &S, const Expr *E,
   QualType Ty = E->getType();
   assert(Ty->isRecordType() && "lvalue was not record?");
   SourceRange Range = E->getSourceRange();
-  const RecordType *RTy = Ty.getCanonicalType()->getAs<RecordType>();
+  const auto *RTy = Ty->getAsCanonical<RecordType>();
   bool DiagEmitted = false;
 
   if (const MemberExpr *ME = dyn_cast<MemberExpr>(E))
@@ -16156,11 +16152,10 @@ ExprResult Sema::BuildBuiltinOffsetOf(SourceLocation BuiltinLoc,
       return ExprError();
 
     // Look for the designated field.
-    const RecordType *RC = CurrentType->getAs<RecordType>();
-    if (!RC)
+    auto *RD = CurrentType->getAsRecordDecl();
+    if (!RD)
       return ExprError(Diag(OC.LocEnd, diag::err_offsetof_record_type)
                        << CurrentType);
-    RecordDecl *RD = RC->getOriginalDecl()->getDefinitionOrSelf();
 
     // C++ [lib.support.types]p5:
     //   The macro offsetof accepts a restricted set of type arguments in this
@@ -16557,8 +16552,7 @@ ExprResult Sema::ActOnBlockStmtExpr(SourceLocation CaretLoc,
     auto *Var = cast<VarDecl>(Cap.getVariable());
     Expr *CopyExpr = nullptr;
     if (getLangOpts().CPlusPlus && Cap.isCopyCapture()) {
-      if (const RecordType *Record =
-              Cap.getCaptureType()->getAs<RecordType>()) {
+      if (auto *Record = Cap.getCaptureType()->getAsCXXRecordDecl()) {
         // The capture logic needs the destructor, so make sure we mark it.
         // Usually this is unnecessary because most local variables have
         // their destructors marked at declaration time, but parameters are
@@ -16786,9 +16780,8 @@ ExprResult Sema::BuildVAArgExpr(SourceLocation BuiltinLoc,
       // va_arg. Instead, get the underlying type of the enumeration and pass
       // that.
       QualType UnderlyingType = TInfo->getType();
-      if (const auto *ET = UnderlyingType->getAs<EnumType>())
-        UnderlyingType =
-            ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+      if (const auto *ED = UnderlyingType->getAsEnumDecl())
+        UnderlyingType = ED->getIntegerType();
       if (Context.typesAreCompatible(PromoteType, UnderlyingType,
                                      /*CompareUnqualified*/ true))
         PromoteType = QualType();
@@ -18785,19 +18778,16 @@ static bool isVariableCapturable(CapturingScopeInfo *CSI, ValueDecl *Var,
   }
   // Prohibit structs with flexible array members too.
   // We cannot capture what is in the tail end of the struct.
-  if (const RecordType *VTTy = Var->getType()->getAs<RecordType>()) {
-    if (VTTy->getOriginalDecl()
-            ->getDefinitionOrSelf()
-            ->hasFlexibleArrayMember()) {
-      if (Diagnose) {
-        if (IsBlock)
-          S.Diag(Loc, diag::err_ref_flexarray_type);
-        else
-          S.Diag(Loc, diag::err_lambda_capture_flexarray_type) << Var;
-        S.Diag(Var->getLocation(), diag::note_previous_decl) << Var;
-      }
-      return false;
+  if (const auto *VTD = Var->getType()->getAsRecordDecl();
+      VTD && VTD->hasFlexibleArrayMember()) {
+    if (Diagnose) {
+      if (IsBlock)
+        S.Diag(Loc, diag::err_ref_flexarray_type);
+      else
+        S.Diag(Loc, diag::err_lambda_capture_flexarray_type) << Var;
+      S.Diag(Var->getLocation(), diag::note_previous_decl) << Var;
     }
+    return false;
   }
   const bool HasBlocksAttr = Var->hasAttr<BlocksAttr>();
   // Lambdas and captured statements are not allowed to capture __block
@@ -19877,7 +19867,7 @@ ExprResult Sema::CheckLValueToRValueConversionOperand(Expr *E) {
   // C++2a [basic.def.odr]p4:
   //   [...] an expression of non-volatile-qualified non-class type to which
   //   the lvalue-to-rvalue conversion is applied [...]
-  if (E->getType().isVolatileQualified() || E->getType()->getAs<RecordType>())
+  if (E->getType().isVolatileQualified() || E->getType()->isRecordType())
     return E;
 
   ExprResult Result =
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index 2a82842095446..5a9279d928465 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -543,7 +543,7 @@ ExprResult Sema::BuildCXXTypeId(QualType TypeInfoType,
   QualType T
     = Context.getUnqualifiedArrayType(Operand->getType().getNonReferenceType(),
                                       Quals);
-  if (T->getAs<RecordType>() &&
+  if (T->isRecordType() &&
       RequireCompleteType(TypeidLoc, T, diag::err_incomplete_typeid))
     return ExprError();
 
@@ -570,9 +570,7 @@ ExprResult Sema::BuildCXXTypeId(QualType TypeInfoType,
     }
 
     QualType T = E->getType();
-    if (const RecordType *RecordT = T->getAs<RecordType>()) {
-      CXXRecordDecl *RecordD = cast<CXXRecordDecl>(RecordT->getOriginalDecl())
-                                   ->getDefinitionOrSelf();
+    if (auto *RecordD = T->getAsCXXRecordDecl()) {
       // C++ [expr.typeid]p3:
       //   [...] If the type of the expression is a class type, the class
       //   shall be completely-defined.
@@ -1975,8 +1973,8 @@ static UsualDeallocFnInfo resolveDeallocationOverload(
 static bool doesUsualArrayDeleteWantSize(Sema &S, SourceLocation loc,
                                          TypeAwareAllocationMode PassType,
                                          QualType allocType) {
-  const RecordType *record =
-    allocType->getBaseElementTypeUnsafe()->getAs<RecordType>();
+  const auto *record =
+      allocType->getBaseElementTypeUnsafe()->getAsCanonical<RecordType>();
   if (!record) return false;
 
   // Try to find an operator delete[] in class scope.
@@ -2297,8 +2295,7 @@ ExprResult Sema::BuildCXXNew(SourceRange Range, bool UseGlobal,
       ConvertedSize = PerformImplicitConversion(
           *ArraySize, Context.getSizeType(), AssignmentAction::Converting);
 
-      if (!ConvertedSize.isInvalid() &&
-          (*ArraySize)->getType()->getAs<RecordType>())
+      if (!ConvertedSize.isInvalid() && (*ArraySize)->getType()->isRecordType())
         // Diagnose the compatibility of this conversion.
         Diag(StartLoc, diag::warn_cxx98_compat_array_size_conversion)
           << (*ArraySize)->getType() << 0 << "'size_t'";
@@ -3055,9 +3052,7 @@ bool Sema::FindAllocationFunctions(
   LookupResult FoundDelete(*this, DeleteName, StartLoc, LookupOrdinaryName);
   if (AllocElemType->isRecordType() &&
       DeleteScope != AllocationFunctionScope::Global) {
-    auto *RD = cast<CXXRecordDecl>(
-                   AllocElemType->castAs<RecordType>()->getOriginalDecl())
-                   ->getDefinitionOrSelf();
+    auto *RD = AllocElemType->castAsCXXRecordDecl();
     LookupQualifiedName(FoundDelete, RD);
   }
   if (FoundDelete.isAmbiguous())
@@ -4070,9 +4065,7 @@ Sema::ActOnCXXDelete(SourceLocation StartLoc, bool UseGlobal,
                                    ? diag::err_delete_incomplete
                                    : diag::warn_delete_incomplete,
                                Ex.get())) {
-        if (const RecordType *RT = PointeeElem->getAs<RecordType>())
-          PointeeRD =
-              cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
+        PointeeRD = PointeeElem->getAsCXXRecordDecl();
       }
     }
 
@@ -4840,10 +4833,7 @@ Sema::PerformImplicitConversion(Expr *From, QualType ToType,
     if (FromType->isVectorType() || ToType->isVectorType())
       StepTy = adjustVectorType(Context, FromType, ToType, &ElTy);
     if (ElTy->isBooleanType()) {
-      assert(FromType->castAs<EnumType>()
-                 ->getOriginalDecl()
-                 ->getDefinitionOrSelf()
-                 ->isFixed() &&
+      assert(FromType->castAsEnumDecl()->isFixed() &&
              SCS.Second == ICK_Integral_Promotion &&
              "only enums with fixed underlying type can promote to bool");
       From = ImpCastExprToType(From, StepTy, CK_IntegralToBoolean, VK_PRValue,
@@ -5529,8 +5519,8 @@ static bool TryClassUnification(Sema &Self, Expr *From, Expr *To,
   //         the same or one is a base class of the other:
   QualType FTy = From->getType();
   QualType TTy = To->getType();
-  const RecordType *FRec = FTy->getAs<RecordType>();
-  const RecordType *TRec = TTy->getAs<RecordType>();
+  const RecordType *FRec = FTy->getAsCanonical<RecordType>();
+  const RecordType *TRec = TTy->getAsCanonical<RecordType>();
   bool FDerivedFromT = FRec && TRec && FRec != TRec &&
                        Self.IsDerivedFrom(QuestionLoc, FTy, TTy);
   if (FRec && TRec && (FRec == TRec || FDerivedFromT ||
@@ -7528,12 +7518,10 @@ ExprResult Sema::IgnoredValueConversions(Expr *E) {
   }
 
   // GCC seems to also exclude expressions of incomplete enum type.
-  if (const EnumType *T = E->getType()->getAs<EnumType>()) {
-    if (!T->getOriginalDecl()->getDefinitionOrSelf()->isComplete()) {
-      // FIXME: stupid workaround for a codegen bug!
-      E = ImpCastExprToType(E, Context.VoidTy, CK_ToVoid).get();
-      return E;
-    }
+  if (const auto *ED = E->getType()->getAsEnumDecl(); ED && !ED->isComplete()) {
+    // FIXME: stupid workaround for a codegen bug!
+    E = ImpCastExprToType(E, Context.VoidTy, CK_ToVoid).get();
+    return E;
   }
 
   ExprResult Res = DefaultFunctionArrayLvalueConversion(E);
diff --git a/clang/lib/Sema/SemaExprObjC.cpp b/clang/lib/Sema/SemaExprObjC.cpp
index 03b5c79cf70e3..331f6e585555b 100644
--- a/clang/lib/Sema/SemaExprObjC.cpp
+++ b/clang/lib/Sema/SemaExprObjC.cpp
@@ -638,8 +638,7 @@ ExprResult SemaObjC::BuildObjCBoxedExpr(SourceRange SR, Expr *ValueExpr) {
     // Look for the appropriate method within NSNumber.
     BoxingMethod = getNSNumberFactoryMethod(*this, Loc, ValueType);
     BoxedType = NSNumberPointer;
-  } else if (const EnumType *ET = ValueType->getAs<EnumType>()) {
-    const EnumDecl *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+  } else if (const auto *ED = ValueType->getAsEnumDecl()) {
     if (!ED->isComplete()) {
       Diag(Loc, diag::err_objc_incomplete_boxed_expression_type)
         << ValueType << ValueExpr->getSourceRange();
@@ -3846,7 +3845,7 @@ static inline T *getObjCBridgeAttr(const TypedefType *TD) {
   QualType QT = TDNDecl->getUnderlyingType();
   if (QT->isPointerType()) {
     QT = QT->getPointeeType();
-    if (const RecordType *RT = QT->getAs<RecordType>()) {
+    if (const RecordType *RT = QT->getAsCanonical<RecordType>()) {
       for (auto *Redecl :
            RT->getOriginalDecl()->getMostRecentDecl()->redecls()) {
         if (auto *attr = Redecl->getAttr<T>())
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index f87715950c74c..8fff463e5aa67 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -236,9 +236,8 @@ static unsigned calculateLegacyCbufferFieldAlign(const ASTContext &Context,
 static unsigned calculateLegacyCbufferSize(const ASTContext &Context,
                                            QualType T) {
   constexpr unsigned CBufferAlign = 16;
-  if (const RecordType *RT = T->getAs<RecordType>()) {
+  if (const auto *RD = T->getAsRecordDecl()) {
     unsigned Size = 0;
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
     for (const FieldDecl *Field : RD->fields()) {
       QualType Ty = Field->getType();
       unsigned FieldSize = calculateLegacyCbufferSize(Context, Ty);
@@ -364,8 +363,8 @@ static bool isInvalidConstantBufferLeafElementType(const Type *Ty) {
   Ty = Ty->getUnqualifiedDesugaredType();
   if (Ty->isHLSLResourceRecord() || Ty->isHLSLResourceRecordArray())
     return true;
-  if (Ty->isRecordType())
-    return Ty->getAsCXXRecordDecl()->isEmpty();
+  if (const auto *RD = Ty->getAsCXXRecordDecl())
+    return RD->isEmpty();
   if (Ty->isConstantArrayType() &&
       isZeroSizedArray(cast<ConstantArrayType>(Ty)))
     return true;
@@ -387,14 +386,14 @@ static bool requiresImplicitBufferLayoutStructure(const CXXRecordDecl *RD) {
     QualType Ty = Field->getType();
     if (isInvalidConstantBufferLeafElementType(Ty.getTypePtr()))
       return true;
-    if (Ty->isRecordType() &&
-        requiresImplicitBufferLayoutStructure(Ty->getAsCXXRecordDecl()))
+    if (const auto *RD = Ty->getAsCXXRecordDecl();
+        RD && requiresImplicitBufferLayoutStructure(RD))
       return true;
   }
   // check bases
   for (const CXXBaseSpecifier &Base : RD->bases())
     if (requiresImplicitBufferLayoutStructure(
-            Base.getType()->getAsCXXRecordDecl()))
+            Base.getType()->castAsCXXRecordDecl()))
       return true;
   return false;
 }
@@ -459,8 +458,7 @@ static FieldDecl *createFieldForHostLayoutStruct(Sema &S, const Type *Ty,
   if (isInvalidConstantBufferLeafElementType(Ty))
     return nullptr;
 
-  if (Ty->isRecordType()) {
-    CXXRecordDecl *RD = Ty->getAsCXXRecordDecl();
+  if (auto *RD = Ty->getAsCXXRecordDecl()) {
     if (requiresImplicitBufferLayoutStructure(RD)) {
       RD = createHostLayoutStruct(S, RD);
       if (!RD)
@@ -511,7 +509,7 @@ static CXXRecordDecl *createHostLayoutStruct(Sema &S,
     assert(NumBases == 1 && "HLSL supports only one base type");
     (void)NumBases;
     CXXBaseSpecifier Base = *StructDecl->bases_begin();
-    CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl();
+    CXXRecordDecl *BaseDecl = Base.getType()->castAsCXXRecordDecl();
     if (requiresImplicitBufferLayoutStructure(BaseDecl)) {
       BaseDecl = createHostLayoutStruct(S, BaseDecl);
       if (BaseDecl) {
@@ -3194,10 +3192,7 @@ static void BuildFlattenedTypeList(QualType BaseTy,
       List.insert(List.end(), VT->getNumElements(), VT->getElementType());
       continue;
     }
-    if (const auto *RT = dyn_cast<RecordType>(T)) {
-      const CXXRecordDecl *RD = RT->getAsCXXRecordDecl();
-      assert(RD && "HLSL record types should all be CXXRecordDecls!");
-
+    if (const auto *RD = T->getAsCXXRecordDecl()) {
       if (RD->isStandardLayout())
         RD = RD->getStandardLayoutBaseWithFields();
 
@@ -3967,19 +3962,19 @@ class InitListTransformer {
       return true;
     }
 
-    if (auto *RTy = Ty->getAs<RecordType>()) {
-      llvm::SmallVector<const RecordType *> RecordTypes;
-      RecordTypes.push_back(RTy);
-      while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
-        CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
+    if (auto *RD = Ty->getAsCXXRecordDecl()) {
+      llvm::SmallVector<CXXRecordDecl *> RecordDecls;
+      RecordDecls.push_back(RD);
+      while (RecordDecls.back()->getNumBases()) {
+        CXXRecordDecl *D = RecordDecls.back();
         assert(D->getNumBases() == 1 &&
                "HLSL doesn't support multiple inheritance");
-        RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>());
+        RecordDecls.push_back(
+            D->bases_begin()->getType()->castAsCXXRecordDecl());
       }
-      while (!RecordTypes.empty()) {
-        const RecordType *RT = RecordTypes.pop_back_val();
-        for (auto *FD :
-             RT->getOriginalDecl()->getDefinitionOrSelf()->fields()) {
+      while (!RecordDecls.empty()) {
+        CXXRecordDecl *RD = RecordDecls.pop_back_val();
+        for (auto *FD : RD->fields()) {
           DeclAccessPair Found = DeclAccessPair::make(FD, FD->getAccess());
           DeclarationNameInfo NameInfo(FD->getDeclName(), E->getBeginLoc());
           ExprResult Res = S.BuildFieldReferenceExpr(
@@ -4016,21 +4011,20 @@ class InitListTransformer {
       for (uint64_t I = 0; I < Size; ++I)
         Inits.push_back(generateInitListsImpl(ElTy));
     }
-    if (auto *RTy = Ty->getAs<RecordType>()) {
-      llvm::SmallVector<const RecordType *> RecordTypes;
-      RecordTypes.push_back(RTy);
-      while (RecordTypes.back()->getAsCXXRecordDecl()->getNumBases()) {
-        CXXRecordDecl *D = RecordTypes.back()->getAsCXXRecordDecl();
+    if (auto *RD = Ty->getAsCXXRecordDecl()) {
+      llvm::SmallVector<CXXRecordDecl *> RecordDecls;
+      RecordDecls.push_back(RD);
+      while (RecordDecls.back()->getNumBases()) {
+        CXXRecordDecl *D = RecordDecls.back();
         assert(D->getNumBases() == 1 &&
                "HLSL doesn't support multiple inheritance");
-        RecordTypes.push_back(D->bases_begin()->getType()->getAs<RecordType>());
+        RecordDecls.push_back(
+            D->bases_begin()->getType()->castAsCXXRecordDecl());
       }
-      while (!RecordTypes.empty()) {
-        const RecordType *RT = RecordTypes.pop_back_val();
-        for (auto *FD :
-             RT->getOriginalDecl()->getDefinitionOrSelf()->fields()) {
+      while (!RecordDecls.empty()) {
+        CXXRecordDecl *RD = RecordDecls.pop_back_val();
+        for (auto *FD : RD->fields())
           Inits.push_back(generateInitListsImpl(FD->getType()));
-        }
       }
     }
     auto *NewInit = new (Ctx) InitListExpr(Ctx, Inits.front()->getBeginLoc(),
diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp
index 60f9d449fc037..91590da35c1eb 100644
--- a/clang/lib/Sema/SemaInit.cpp
+++ b/clang/lib/Sema/SemaInit.cpp
@@ -774,7 +774,7 @@ void InitListChecker::FillInEmptyInitForField(unsigned Init, FieldDecl *Field,
     = InitializedEntity::InitializeMember(Field, &ParentEntity);
 
   if (Init >= NumInits || !ILE->getInit(Init)) {
-    if (const RecordType *RType = ILE->getType()->getAs<RecordType>())
+    if (const RecordType *RType = ILE->getType()->getAsCanonical<RecordType>())
       if (!RType->getOriginalDecl()->isUnion())
         assert((Init < NumInits || VerifyOnly) &&
                "This ILE should have been expanded");
@@ -921,8 +921,7 @@ InitListChecker::FillInEmptyInitializations(const InitializedEntity &Entity,
   if (ILE->isTransparent())
     return;
 
-  if (const RecordType *RType = ILE->getType()->getAs<RecordType>()) {
-    const RecordDecl *RDecl = RType->getOriginalDecl()->getDefinitionOrSelf();
+  if (const auto *RDecl = ILE->getType()->getAsRecordDecl()) {
     if (RDecl->isUnion() && ILE->getInitializedFieldInUnion()) {
       FillInEmptyInitForField(0, ILE->getInitializedFieldInUnion(), Entity, ILE,
                               RequiresSecondPass, FillWithNoInit);
@@ -1126,8 +1125,7 @@ int InitListChecker::numArrayElements(QualType DeclType) {
 }
 
 int InitListChecker::numStructUnionElements(QualType DeclType) {
-  RecordDecl *structDecl =
-      DeclType->castAs<RecordType>()->getOriginalDecl()->getDefinitionOrSelf();
+  auto *structDecl = DeclType->castAsRecordDecl();
   int InitializableMembers = 0;
   if (auto *CXXRD = dyn_cast<CXXRecordDecl>(structDecl))
     InitializableMembers += CXXRD->getNumBases();
@@ -1156,22 +1154,14 @@ static bool isIdiomaticBraceElisionEntity(const InitializedEntity &Entity) {
 
   // Allows elide brace initialization for aggregates with empty base.
   if (Entity.getKind() == InitializedEntity::EK_Base) {
-    auto *ParentRD = Entity.getParent()
-                         ->getType()
-                         ->castAs<RecordType>()
-                         ->getOriginalDecl()
-                         ->getDefinitionOrSelf();
+    auto *ParentRD = Entity.getParent()->getType()->castAsRecordDecl();
     CXXRecordDecl *CXXRD = cast<CXXRecordDecl>(ParentRD);
     return CXXRD->getNumBases() == 1 && CXXRD->field_empty();
   }
 
   // Allow brace elision if the only subobject is a field.
   if (Entity.getKind() == InitializedEntity::EK_Member) {
-    auto *ParentRD = Entity.getParent()
-                         ->getType()
-                         ->castAs<RecordType>()
-                         ->getOriginalDecl()
-                         ->getDefinitionOrSelf();
+    auto *ParentRD = Entity.getParent()->getType()->castAsRecordDecl();
     if (CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(ParentRD)) {
       if (CXXRD->getNumBases()) {
         return false;
@@ -2348,7 +2338,9 @@ void InitListChecker::CheckStructUnionTypes(
            Field != FieldEnd; ++Field) {
         if (Field->hasInClassInitializer() ||
             (Field->isAnonymousStructOrUnion() &&
-             Field->getType()->getAsCXXRecordDecl()->hasInClassInitializer())) {
+             Field->getType()
+                 ->castAsCXXRecordDecl()
+                 ->hasInClassInitializer())) {
           StructuredList->setInitializedFieldInUnion(*Field);
           // FIXME: Actually build a CXXDefaultInitExpr?
           return;
@@ -4535,11 +4527,7 @@ static void TryConstructorInitialization(Sema &S,
     }
   }
 
-  const RecordType *DestRecordType = DestType->getAs<RecordType>();
-  assert(DestRecordType && "Constructor initialization requires record type");
-  auto *DestRecordDecl = cast<CXXRecordDecl>(DestRecordType->getOriginalDecl())
-                             ->getDefinitionOrSelf();
-
+  auto *DestRecordDecl = DestType->castAsCXXRecordDecl();
   // Build the candidate set directly in the initialization sequence
   // structure, so that it will persist if we fail.
   OverloadCandidateSet &CandidateSet = Sequence.getFailedCandidateSet();
@@ -5026,7 +5014,7 @@ static void TryListInitialization(Sema &S,
       //     class type with a default constructor, the object is
       //     value-initialized.
       if (InitList->getNumInits() == 0) {
-        CXXRecordDecl *RD = DestType->getAsCXXRecordDecl();
+        CXXRecordDecl *RD = DestType->castAsCXXRecordDecl();
         if (S.LookupDefaultConstructor(RD)) {
           TryValueInitialization(S, Entity, Kind, Sequence, InitList);
           return;
@@ -5057,10 +5045,9 @@ static void TryListInitialization(Sema &S,
     //     is direct-list-initialization, the object is initialized with the
     //     value T(v); if a narrowing conversion is required to convert v to
     //     the underlying type of T, the program is ill-formed.
-    auto *ET = DestType->getAs<EnumType>();
     if (S.getLangOpts().CPlusPlus17 &&
-        Kind.getKind() == InitializationKind::IK_DirectList && ET &&
-        ET->getOriginalDecl()->getDefinitionOrSelf()->isFixed() &&
+        Kind.getKind() == InitializationKind::IK_DirectList &&
+        DestType->isEnumeralType() && DestType->castAsEnumDecl()->isFixed() &&
         !S.Context.hasSameUnqualifiedType(E->getType(), DestType) &&
         (E->getType()->isIntegralOrUnscopedEnumerationType() ||
          E->getType()->isFloatingType())) {
@@ -5165,14 +5152,10 @@ static OverloadingResult TryRefInitWithConversionFunction(
   bool AllowExplicitCtors = false;
   bool AllowExplicitConvs = Kind.allowExplicitConversionFunctionsInRefBinding();
 
-  const RecordType *T1RecordType = nullptr;
-  if (AllowRValues && (T1RecordType = T1->getAs<RecordType>()) &&
-      S.isCompleteType(Kind.getLocation(), T1)) {
+  auto *T1RecordDecl = AllowRValues ? T1->getAsCXXRecordDecl() : nullptr;
+  if (T1RecordDecl && S.isCompleteType(Kind.getLocation(), T1)) {
     // The type we're converting to is a class type. Enumerate its constructors
     // to see if there is a suitable conversion.
-    auto *T1RecordDecl = cast<CXXRecordDecl>(T1RecordType->getOriginalDecl())
-                             ->getDefinitionOrSelf();
-
     for (NamedDecl *D : S.LookupConstructors(T1RecordDecl)) {
       auto Info = getConstructorInfo(D);
       if (!Info.Constructor)
@@ -5194,18 +5177,13 @@ static OverloadingResult TryRefInitWithConversionFunction(
       }
     }
   }
-  if (T1RecordType &&
-      T1RecordType->getOriginalDecl()->getDefinitionOrSelf()->isInvalidDecl())
+  if (T1RecordDecl && T1RecordDecl->isInvalidDecl())
     return OR_No_Viable_Function;
 
-  const RecordType *T2RecordType = nullptr;
-  if ((T2RecordType = T2->getAs<RecordType>()) &&
-      S.isCompleteType(Kind.getLocation(), T2)) {
+  const CXXRecordDecl *T2RecordDecl = T2->getAsCXXRecordDecl();
+  if (T2RecordDecl && S.isCompleteType(Kind.getLocation(), T2)) {
     // The type we're converting from is a class type, enumerate its conversion
     // functions.
-    auto *T2RecordDecl = cast<CXXRecordDecl>(T2RecordType->getOriginalDecl())
-                             ->getDefinitionOrSelf();
-
     const auto &Conversions = T2RecordDecl->getVisibleConversionFunctions();
     for (auto I = Conversions.begin(), E = Conversions.end(); I != E; ++I) {
       NamedDecl *D = *I;
@@ -5240,8 +5218,7 @@ static OverloadingResult TryRefInitWithConversionFunction(
       }
     }
   }
-  if (T2RecordType &&
-      T2RecordType->getOriginalDecl()->getDefinitionOrSelf()->isInvalidDecl())
+  if (T2RecordDecl && T2RecordDecl->isInvalidDecl())
     return OR_No_Viable_Function;
 
   SourceLocation DeclLoc = Initializer->getBeginLoc();
@@ -5718,64 +5695,60 @@ static void TryValueInitialization(Sema &S,
   //     -- if T is an array type, then each element is value-initialized;
   T = S.Context.getBaseElementType(T);
 
-  if (const RecordType *RT = T->getAs<RecordType>()) {
-    if (CXXRecordDecl *ClassDecl =
-            dyn_cast<CXXRecordDecl>(RT->getOriginalDecl())) {
-      ClassDecl = ClassDecl->getDefinitionOrSelf();
-      bool NeedZeroInitialization = true;
-      // C++98:
-      // -- if T is a class type (clause 9) with a user-declared constructor
-      //    (12.1), then the default constructor for T is called (and the
-      //    initialization is ill-formed if T has no accessible default
-      //    constructor);
-      // C++11:
-      // -- if T is a class type (clause 9) with either no default constructor
-      //    (12.1 [class.ctor]) or a default constructor that is user-provided
-      //    or deleted, then the object is default-initialized;
-      //
-      // Note that the C++11 rule is the same as the C++98 rule if there are no
-      // defaulted or deleted constructors, so we just use it unconditionally.
-      CXXConstructorDecl *CD = S.LookupDefaultConstructor(ClassDecl);
-      if (!CD || !CD->getCanonicalDecl()->isDefaulted() || CD->isDeleted())
-        NeedZeroInitialization = false;
-
-      // -- if T is a (possibly cv-qualified) non-union class type without a
-      //    user-provided or deleted default constructor, then the object is
-      //    zero-initialized and, if T has a non-trivial default constructor,
-      //    default-initialized;
-      // The 'non-union' here was removed by DR1502. The 'non-trivial default
-      // constructor' part was removed by DR1507.
-      if (NeedZeroInitialization)
-        Sequence.AddZeroInitializationStep(Entity.getType());
-
-      // C++03:
-      // -- if T is a non-union class type without a user-declared constructor,
-      //    then every non-static data member and base class component of T is
-      //    value-initialized;
-      // [...] A program that calls for [...] value-initialization of an
-      // entity of reference type is ill-formed.
-      //
-      // C++11 doesn't need this handling, because value-initialization does not
-      // occur recursively there, and the implicit default constructor is
-      // defined as deleted in the problematic cases.
-      if (!S.getLangOpts().CPlusPlus11 &&
-          ClassDecl->hasUninitializedReferenceMember()) {
-        Sequence.SetFailed(InitializationSequence::FK_TooManyInitsForReference);
-        return;
-      }
-
-      // If this is list-value-initialization, pass the empty init list on when
-      // building the constructor call. This affects the semantics of a few
-      // things (such as whether an explicit default constructor can be called).
-      Expr *InitListAsExpr = InitList;
-      MultiExprArg Args(&InitListAsExpr, InitList ? 1 : 0);
-      bool InitListSyntax = InitList;
+  if (auto *ClassDecl = T->getAsCXXRecordDecl()) {
+    bool NeedZeroInitialization = true;
+    // C++98:
+    // -- if T is a class type (clause 9) with a user-declared constructor
+    //    (12.1), then the default constructor for T is called (and the
+    //    initialization is ill-formed if T has no accessible default
+    //    constructor);
+    // C++11:
+    // -- if T is a class type (clause 9) with either no default constructor
+    //    (12.1 [class.ctor]) or a default constructor that is user-provided
+    //    or deleted, then the object is default-initialized;
+    //
+    // Note that the C++11 rule is the same as the C++98 rule if there are no
+    // defaulted or deleted constructors, so we just use it unconditionally.
+    CXXConstructorDecl *CD = S.LookupDefaultConstructor(ClassDecl);
+    if (!CD || !CD->getCanonicalDecl()->isDefaulted() || CD->isDeleted())
+      NeedZeroInitialization = false;
+
+    // -- if T is a (possibly cv-qualified) non-union class type without a
+    //    user-provided or deleted default constructor, then the object is
+    //    zero-initialized and, if T has a non-trivial default constructor,
+    //    default-initialized;
+    // The 'non-union' here was removed by DR1502. The 'non-trivial default
+    // constructor' part was removed by DR1507.
+    if (NeedZeroInitialization)
+      Sequence.AddZeroInitializationStep(Entity.getType());
 
-      // FIXME: Instead of creating a CXXConstructExpr of array type here,
-      // wrap a class-typed CXXConstructExpr in an ArrayInitLoopExpr.
-      return TryConstructorInitialization(
-          S, Entity, Kind, Args, T, Entity.getType(), Sequence, InitListSyntax);
+    // C++03:
+    // -- if T is a non-union class type without a user-declared constructor,
+    //    then every non-static data member and base class component of T is
+    //    value-initialized;
+    // [...] A program that calls for [...] value-initialization of an
+    // entity of reference type is ill-formed.
+    //
+    // C++11 doesn't need this handling, because value-initialization does not
+    // occur recursively there, and the implicit default constructor is
+    // defined as deleted in the problematic cases.
+    if (!S.getLangOpts().CPlusPlus11 &&
+        ClassDecl->hasUninitializedReferenceMember()) {
+      Sequence.SetFailed(InitializationSequence::FK_TooManyInitsForReference);
+      return;
     }
+
+    // If this is list-value-initialization, pass the empty init list on when
+    // building the constructor call. This affects the semantics of a few
+    // things (such as whether an explicit default constructor can be called).
+    Expr *InitListAsExpr = InitList;
+    MultiExprArg Args(&InitListAsExpr, InitList ? 1 : 0);
+    bool InitListSyntax = InitList;
+
+    // FIXME: Instead of creating a CXXConstructExpr of array type here,
+    // wrap a class-typed CXXConstructExpr in an ArrayInitLoopExpr.
+    return TryConstructorInitialization(
+        S, Entity, Kind, Args, T, Entity.getType(), Sequence, InitListSyntax);
   }
 
   Sequence.AddZeroInitializationStep(Entity.getType());
@@ -5917,10 +5890,8 @@ static void TryOrBuildParenListInitialization(
           AT->getElementType(), llvm::APInt(/*numBits=*/32, ArrayLength),
           /*SizeExpr=*/nullptr, ArraySizeModifier::Normal, 0);
     }
-  } else if (auto *RT = Entity.getType()->getAs<RecordType>()) {
-    bool IsUnion = RT->isUnionType();
-    const auto *RD =
-        cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
+  } else if (auto *RD = Entity.getType()->getAsCXXRecordDecl()) {
+    bool IsUnion = RD->isUnion();
     if (RD->isInvalidDecl()) {
       // Exit early to avoid confusion when processing members.
       // We do the same for braced list initialization in
@@ -6106,15 +6077,12 @@ static void TryUserDefinedConversion(Sema &S,
   // explicit conversion operators.
   bool AllowExplicit = Kind.AllowExplicit();
 
-  if (const RecordType *DestRecordType = DestType->getAs<RecordType>()) {
+  if (DestType->isRecordType()) {
     // The type we're converting to is a class type. Enumerate its constructors
     // to see if there is a suitable conversion.
-    auto *DestRecordDecl =
-        cast<CXXRecordDecl>(DestRecordType->getOriginalDecl())
-            ->getDefinitionOrSelf();
-
     // Try to complete the type we're converting to.
     if (S.isCompleteType(Kind.getLocation(), DestType)) {
+      auto *DestRecordDecl = DestType->castAsCXXRecordDecl();
       for (NamedDecl *D : S.LookupConstructors(DestRecordDecl)) {
         auto Info = getConstructorInfo(D);
         if (!Info.Constructor)
@@ -6140,17 +6108,14 @@ static void TryUserDefinedConversion(Sema &S,
 
   SourceLocation DeclLoc = Initializer->getBeginLoc();
 
-  if (const RecordType *SourceRecordType = SourceType->getAs<RecordType>()) {
+  if (SourceType->isRecordType()) {
     // The type we're converting from is a class type, enumerate its conversion
     // functions.
 
     // We can only enumerate the conversion functions for a complete type; if
     // the type isn't complete, simply skip this step.
     if (S.isCompleteType(DeclLoc, SourceType)) {
-      auto *SourceRecordDecl =
-          cast<CXXRecordDecl>(SourceRecordType->getOriginalDecl())
-              ->getDefinitionOrSelf();
-
+      auto *SourceRecordDecl = SourceType->castAsCXXRecordDecl();
       const auto &Conversions =
           SourceRecordDecl->getVisibleConversionFunctions();
       for (auto I = Conversions.begin(), E = Conversions.end(); I != E; ++I) {
@@ -6237,7 +6202,7 @@ static void TryUserDefinedConversion(Sema &S,
   Sequence.AddUserConversionStep(Function, Best->FoundDecl, ConvType,
                                  HadMultipleCandidates);
 
-  if (ConvType->getAs<RecordType>()) {
+  if (ConvType->isRecordType()) {
     //   The call is used to direct-initialize [...] the object that is the
     //   destination of the copy-initialization.
     //
@@ -7180,10 +7145,7 @@ static ExprResult CopyObject(Sema &S,
     return CurInit;
   // Determine which class type we're copying to.
   Expr *CurInitExpr = (Expr *)CurInit.get();
-  CXXRecordDecl *Class = nullptr;
-  if (const RecordType *Record = T->getAs<RecordType>())
-    Class =
-        cast<CXXRecordDecl>(Record->getOriginalDecl())->getDefinitionOrSelf();
+  auto *Class = T->getAsCXXRecordDecl();
   if (!Class)
     return CurInit;
 
@@ -7328,7 +7290,7 @@ static void CheckCXX98CompatAccessibleCopy(Sema &S,
                                            Expr *CurInitExpr) {
   assert(S.getLangOpts().CPlusPlus11);
 
-  const RecordType *Record = CurInitExpr->getType()->getAs<RecordType>();
+  auto *Record = CurInitExpr->getType()->getAsCXXRecordDecl();
   if (!Record)
     return;
 
@@ -7338,8 +7300,7 @@ static void CheckCXX98CompatAccessibleCopy(Sema &S,
 
   // Find constructors which would have been considered.
   OverloadCandidateSet CandidateSet(Loc, OverloadCandidateSet::CSK_Normal);
-  DeclContext::lookup_result Ctors = S.LookupConstructors(
-      cast<CXXRecordDecl>(Record->getOriginalDecl())->getDefinitionOrSelf());
+  DeclContext::lookup_result Ctors = S.LookupConstructors(Record);
 
   // Perform overload resolution.
   OverloadCandidateSet::iterator Best;
@@ -7803,7 +7764,7 @@ ExprResult InitializationSequence::Perform(Sema &S,
         DiagID = diag::ext_default_init_const;
 
       S.Diag(Kind.getLocation(), DiagID)
-          << DestType << (bool)DestType->getAs<RecordType>()
+          << DestType << DestType->isRecordType()
           << FixItHint::CreateInsertion(ZeroInitializationFixitLoc,
                                         ZeroInitializationFixit);
     }
@@ -8169,10 +8130,8 @@ ExprResult InitializationSequence::Perform(Sema &S,
         // FIXME: It makes no sense to do this here. This should happen
         // regardless of how we initialized the entity.
         QualType T = CurInit.get()->getType();
-        if (const RecordType *Record = T->getAs<RecordType>()) {
-          CXXDestructorDecl *Destructor =
-              S.LookupDestructor(cast<CXXRecordDecl>(Record->getOriginalDecl())
-                                     ->getDefinitionOrSelf());
+        if (auto *Record = T->castAsCXXRecordDecl()) {
+          CXXDestructorDecl *Destructor = S.LookupDestructor(Record);
           S.CheckDestructorAccess(CurInit.get()->getBeginLoc(), Destructor,
                                   S.PDiag(diag::err_access_dtor_temp) << T);
           S.MarkFunctionReferenced(CurInit.get()->getBeginLoc(), Destructor);
@@ -8560,9 +8519,8 @@ ExprResult InitializationSequence::Perform(Sema &S,
             S.isStdInitializerList(Step->Type, &ElementType);
         assert(IsStdInitializerList &&
                "StdInitializerList step to non-std::initializer_list");
-        const CXXRecordDecl *Record =
-            Step->Type->getAsCXXRecordDecl()->getDefinition();
-        assert(Record && Record->isCompleteDefinition() &&
+        const auto *Record = Step->Type->castAsCXXRecordDecl();
+        assert(Record->isCompleteDefinition() &&
                "std::initializer_list should have already be "
                "complete/instantiated by this point");
 
@@ -9225,11 +9183,8 @@ bool InitializationSequence::Diagnose(Sema &S,
                 << S.Context.getCanonicalTagType(Constructor->getParent())
                 << /*base=*/0 << Entity.getType() << InheritedFrom;
 
-            RecordDecl *BaseDecl = Entity.getBaseSpecifier()
-                                       ->getType()
-                                       ->castAs<RecordType>()
-                                       ->getOriginalDecl()
-                                       ->getDefinitionOrSelf();
+            auto *BaseDecl =
+                Entity.getBaseSpecifier()->getType()->castAsRecordDecl();
             S.Diag(BaseDecl->getLocation(), diag::note_previous_decl)
                 << S.Context.getCanonicalTagType(BaseDecl);
           } else {
@@ -9242,8 +9197,7 @@ bool InitializationSequence::Diagnose(Sema &S,
             S.Diag(Entity.getDecl()->getLocation(),
                    diag::note_member_declared_at);
 
-            if (const RecordType *Record
-                                 = Entity.getType()->getAs<RecordType>())
+            if (const auto *Record = Entity.getType()->getAs<RecordType>())
               S.Diag(Record->getOriginalDecl()->getLocation(),
                      diag::note_previous_decl)
                   << S.Context.getCanonicalTagType(Record->getOriginalDecl());
@@ -9326,7 +9280,7 @@ bool InitializationSequence::Diagnose(Sema &S,
           << VD;
     } else {
       S.Diag(Kind.getLocation(), diag::err_default_init_const)
-          << DestType << (bool)DestType->getAs<RecordType>();
+          << DestType << DestType->isRecordType();
     }
     break;
 
@@ -10021,7 +9975,7 @@ QualType Sema::DeduceTemplateSpecializationFromInitializer(
         // dependent. e.g.
         //   using AliasFoo = Foo<bool>;
         if (const auto *CTSD = llvm::dyn_cast<ClassTemplateSpecializationDecl>(
-                RT->getAsCXXRecordDecl()))
+                RT->getOriginalDecl()))
           Template = CTSD->getSpecializedTemplate();
       }
     }
diff --git a/clang/lib/Sema/SemaLambda.cpp b/clang/lib/Sema/SemaLambda.cpp
index 8a81cbe2623d8..fbc2e7eb30676 100644
--- a/clang/lib/Sema/SemaLambda.cpp
+++ b/clang/lib/Sema/SemaLambda.cpp
@@ -641,9 +641,8 @@ static EnumDecl *findEnumForBlockReturn(Expr *E) {
   }
 
   //   - it is an expression of that formal enum type.
-  if (const EnumType *ET = E->getType()->getAs<EnumType>()) {
-    return ET->getOriginalDecl()->getDefinitionOrSelf();
-  }
+  if (auto *ED = E->getType()->getAsEnumDecl())
+    return ED;
 
   // Otherwise, nope.
   return nullptr;
diff --git a/clang/lib/Sema/SemaLookup.cpp b/clang/lib/Sema/SemaLookup.cpp
index e28492b579564..5ae82eaa50684 100644
--- a/clang/lib/Sema/SemaLookup.cpp
+++ b/clang/lib/Sema/SemaLookup.cpp
@@ -2770,10 +2770,7 @@ bool Sema::LookupInSuper(LookupResult &R, CXXRecordDecl *Class) {
   // members of Class itself.  That is, the naming class is Class, and the
   // access includes the access of the base.
   for (const auto &BaseSpec : Class->bases()) {
-    CXXRecordDecl *RD =
-        cast<CXXRecordDecl>(
-            BaseSpec.getType()->castAs<RecordType>()->getOriginalDecl())
-            ->getDefinitionOrSelf();
+    auto *RD = BaseSpec.getType()->castAsCXXRecordDecl();
     LookupResult Result(*this, R.getLookupNameInfo(), R.getLookupKind());
     Result.setBaseObjectType(Context.getCanonicalTagType(Class));
     LookupQualifiedName(Result, RD);
@@ -3114,17 +3111,15 @@ addAssociatedClassesAndNamespaces(AssociatedLookup &Result,
 
     // Visit the base classes.
     for (const auto &Base : Class->bases()) {
-      const RecordType *BaseType = Base.getType()->getAs<RecordType>();
+      CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl();
       // In dependent contexts, we do ADL twice, and the first time around,
       // the base type might be a dependent TemplateSpecializationType, or a
       // TemplateTypeParmType. If that happens, simply ignore it.
       // FIXME: If we want to support export, we probably need to add the
       // namespace of the template in a TemplateSpecializationType, or even
       // the classes and namespaces of known non-dependent arguments.
-      if (!BaseType)
+      if (!BaseDecl)
         continue;
-      CXXRecordDecl *BaseDecl = cast<CXXRecordDecl>(BaseType->getOriginalDecl())
-                                    ->getDefinitionOrSelf();
       if (Result.addClassTransitive(BaseDecl)) {
         // Find the associated namespace for this base class.
         DeclContext *BaseCtx = BaseDecl->getDeclContext();
@@ -3209,8 +3204,7 @@ addAssociatedClassesAndNamespaces(AssociatedLookup &Result, QualType Ty) {
     //        member’s class; else it has no associated class.
     case Type::Enum: {
       // FIXME: This should use the original decl.
-      EnumDecl *Enum =
-          cast<EnumType>(T)->getOriginalDecl()->getDefinitionOrSelf();
+      auto *Enum = T->castAsEnumDecl();
 
       DeclContext *Ctx = Enum->getDeclContext();
       if (CXXRecordDecl *EnclosingClass = dyn_cast<CXXRecordDecl>(Ctx))
@@ -4262,10 +4256,9 @@ class LookupVisibleHelper {
             continue;
           RD = TD->getTemplatedDecl();
         } else {
-          const auto *Record = BaseType->getAs<RecordType>();
-          if (!Record)
+          RD = BaseType->getAsCXXRecordDecl();
+          if (!RD)
             continue;
-          RD = Record->getOriginalDecl()->getDefinitionOrSelf();
         }
 
         // FIXME: It would be nice to be able to determine whether referencing
diff --git a/clang/lib/Sema/SemaObjC.cpp b/clang/lib/Sema/SemaObjC.cpp
index 8d8d5e87afe73..4f9470a361d2d 100644
--- a/clang/lib/Sema/SemaObjC.cpp
+++ b/clang/lib/Sema/SemaObjC.cpp
@@ -1380,7 +1380,7 @@ SemaObjC::ObjCSubscriptKind SemaObjC::CheckSubscriptingKind(Expr *FromE) {
 
   // If we don't have a class type in C++, there's no way we can get an
   // expression of integral or enumeration type.
-  const RecordType *RecordTy = T->getAs<RecordType>();
+  const RecordType *RecordTy = T->getAsCanonical<RecordType>();
   if (!RecordTy && (T->isObjCObjectPointerType() || T->isVoidPointerType()))
     // All other scalar cases are assumed to be dictionary indexing which
     // caller handles, with diagnostics if needed.
@@ -1507,7 +1507,7 @@ bool SemaObjC::isCFStringType(QualType T) {
   if (!PT)
     return false;
 
-  const auto *RT = PT->getPointeeType()->getAs<RecordType>();
+  const auto *RT = PT->getPointeeType()->getAsCanonical<RecordType>();
   if (!RT)
     return false;
 
diff --git a/clang/lib/Sema/SemaObjCProperty.cpp b/clang/lib/Sema/SemaObjCProperty.cpp
index bf6c364e40cc4..1880cec6ec8e5 100644
--- a/clang/lib/Sema/SemaObjCProperty.cpp
+++ b/clang/lib/Sema/SemaObjCProperty.cpp
@@ -1320,10 +1320,8 @@ Decl *SemaObjC::ActOnPropertyImplDecl(
         CompleteTypeErr = true;
       }
       if (!CompleteTypeErr) {
-        const RecordType *RecordTy = PropertyIvarType->getAs<RecordType>();
-        if (RecordTy && RecordTy->getOriginalDecl()
-                            ->getDefinitionOrSelf()
-                            ->hasFlexibleArrayMember()) {
+        if (const auto *RD = PropertyIvarType->getAsRecordDecl();
+            RD && RD->hasFlexibleArrayMember()) {
           Diag(PropertyIvarLoc, diag::err_synthesize_variable_sized_ivar)
             << PropertyIvarType;
           CompleteTypeErr = true; // suppress later diagnostics about the ivar
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 7d800c446b595..8f666cee72937 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -18628,12 +18628,12 @@ buildDeclareReductionRef(Sema &SemaRef, SourceLocation Loc, SourceRange Range,
   //        the set of member candidates is empty.
   LookupResult Lookup(SemaRef, ReductionId, Sema::LookupOMPReductionName);
   Lookup.suppressDiagnostics();
-  if (const auto *TyRec = Ty->getAs<RecordType>()) {
+  if (Ty->isRecordType()) {
     // Complete the type if it can be completed.
     // If the type is neither complete nor being defined, bail out now.
     bool IsComplete = SemaRef.isCompleteType(Loc, Ty);
-    RecordDecl *RD = TyRec->getOriginalDecl()->getDefinition();
-    if (IsComplete || RD) {
+    auto *RD = Ty->castAsRecordDecl();
+    if (IsComplete || RD->isBeingDefined()) {
       Lookup.clear();
       SemaRef.LookupQualifiedName(Lookup, RD);
       if (Lookup.empty()) {
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 7c4405b414c47..14fa8478fe317 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -369,8 +369,8 @@ NarrowingKind StandardConversionSequence::getNarrowingKind(
   // A conversion to an enumeration type is narrowing if the conversion to
   // the underlying type is narrowing. This only arises for expressions of
   // the form 'Enum{init}'.
-  if (auto *ET = ToType->getAs<EnumType>())
-    ToType = ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = ToType->getAsEnumDecl())
+    ToType = ED->getIntegerType();
 
   switch (Second) {
   // 'bool' is an integral type; dispatch to the right place to handle it.
@@ -1058,13 +1058,12 @@ static bool shouldAddReversedEqEq(Sema &S, SourceLocation OpLoc,
   if (isa<CXXMethodDecl>(EqFD)) {
     // If F is a class member, search scope is class type of first operand.
     QualType RHS = FirstOperand->getType();
-    auto *RHSRec = RHS->getAs<RecordType>();
+    auto *RHSRec = RHS->getAsCXXRecordDecl();
     if (!RHSRec)
       return true;
     LookupResult Members(S, NotEqOp, OpLoc,
                          Sema::LookupNameKind::LookupMemberName);
-    S.LookupQualifiedName(Members,
-                          RHSRec->getOriginalDecl()->getDefinitionOrSelf());
+    S.LookupQualifiedName(Members, RHSRec);
     Members.suppressAccessDiagnostics();
     for (NamedDecl *Op : Members)
       if (FunctionsCorrespond(S.Context, EqFD, Op->getAsFunction()))
@@ -1802,7 +1801,7 @@ TryImplicitConversion(Sema &S, Expr *From, QualType ToType,
   //   constructor (i.e., a user-defined conversion function) is
   //   called for those cases.
   QualType FromType = From->getType();
-  if (ToType->getAs<RecordType>() && FromType->getAs<RecordType>() &&
+  if (ToType->isRecordType() &&
       (S.Context.hasSameUnqualifiedType(FromType, ToType) ||
        S.IsDerivedFrom(From->getBeginLoc(), FromType, ToType))) {
     ICS.setStandard();
@@ -2662,11 +2661,9 @@ bool Sema::IsIntegralPromotion(Expr *From, QualType FromType, QualType ToType) {
   //   integral promotion can be applied to its underlying type, a prvalue of an
   //   unscoped enumeration type whose underlying type is fixed can also be
   //   converted to a prvalue of the promoted underlying type.
-  if (const EnumType *FromEnumType = FromType->getAs<EnumType>()) {
+  if (const auto *FromED = FromType->getAsEnumDecl()) {
     // C++0x 7.2p9: Note that this implicit enum to int conversion is not
     // provided for a scoped enumeration.
-    const EnumDecl *FromED =
-        FromEnumType->getOriginalDecl()->getDefinitionOrSelf();
     if (FromED->isScoped())
       return false;
 
@@ -3969,7 +3966,7 @@ IsUserDefinedConversion(Sema &S, Expr *From, QualType ToType,
 
   // If the type we are conversion to is a class type, enumerate its
   // constructors.
-  if (const RecordType *ToRecordType = ToType->getAs<RecordType>()) {
+  if (const RecordType *ToRecordType = ToType->getAsCanonical<RecordType>()) {
     // C++ [over.match.ctor]p1:
     //   When objects of class type are direct-initialized (8.5), or
     //   copy-initialized from an expression of the same or a
@@ -3979,7 +3976,7 @@ IsUserDefinedConversion(Sema &S, Expr *From, QualType ToType,
     //   that class. The argument list is the expression-list within
     //   the parentheses of the initializer.
     if (S.Context.hasSameUnqualifiedType(ToType, From->getType()) ||
-        (From->getType()->getAs<RecordType>() &&
+        (From->getType()->isRecordType() &&
          S.IsDerivedFrom(From->getBeginLoc(), From->getType(), ToType)))
       ConstructorsOnly = true;
 
@@ -4059,7 +4056,7 @@ IsUserDefinedConversion(Sema &S, Expr *From, QualType ToType,
   } else if (!S.isCompleteType(From->getBeginLoc(), From->getType())) {
     // No conversion functions from incomplete types.
   } else if (const RecordType *FromRecordType =
-                 From->getType()->getAs<RecordType>()) {
+                 From->getType()->getAsCanonical<RecordType>()) {
     if (auto *FromRecordDecl =
             dyn_cast<CXXRecordDecl>(FromRecordType->getOriginalDecl())) {
       FromRecordDecl = FromRecordDecl->getDefinitionOrSelf();
@@ -4509,12 +4506,10 @@ getFixedEnumPromtion(Sema &S, const StandardConversionSequence &SCS) {
   if (SCS.Second != ICK_Integral_Promotion)
     return FixedEnumPromotion::None;
 
-  QualType FromType = SCS.getFromType();
-  if (!FromType->isEnumeralType())
+  const auto *Enum = SCS.getFromType()->getAsEnumDecl();
+  if (!Enum)
     return FixedEnumPromotion::None;
 
-  EnumDecl *Enum =
-      FromType->castAs<EnumType>()->getOriginalDecl()->getDefinitionOrSelf();
   if (!Enum->isFixed())
     return FixedEnumPromotion::None;
 
@@ -5150,10 +5145,7 @@ FindConversionForRefInit(Sema &S, ImplicitConversionSequence &ICS,
                          Expr *Init, QualType T2, bool AllowRvalues,
                          bool AllowExplicit) {
   assert(T2->isRecordType() && "Can only find conversions of record types.");
-  auto *T2RecordDecl =
-      cast<CXXRecordDecl>(T2->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
-
+  auto *T2RecordDecl = T2->castAsCXXRecordDecl();
   OverloadCandidateSet CandidateSet(
       DeclLoc, OverloadCandidateSet::CSK_InitByUserDefinedConversion);
   const auto &Conversions = T2RecordDecl->getVisibleConversionFunctions();
@@ -6821,7 +6813,7 @@ ExprResult Sema::PerformContextualImplicitConversion(
 
   // We can only perform contextual implicit conversions on objects of class
   // type.
-  const RecordType *RecordTy = T->getAs<RecordType>();
+  const RecordType *RecordTy = T->getAsCanonical<RecordType>();
   if (!RecordTy || !getLangOpts().CPlusPlus) {
     if (!Converter.Suppress)
       Converter.diagnoseNoMatch(*this, Loc, T) << From->getSourceRange();
@@ -8341,10 +8333,7 @@ void Sema::AddConversionCandidate(
   QualType ObjectType = From->getType();
   if (const auto *FromPtrType = ObjectType->getAs<PointerType>())
     ObjectType = FromPtrType->getPointeeType();
-  const auto *ConversionContext =
-      cast<CXXRecordDecl>(ObjectType->castAs<RecordType>()->getOriginalDecl())
-          ->getDefinitionOrSelf();
-
+  const auto *ConversionContext = ObjectType->castAsCXXRecordDecl();
   // C++23 [over.best.ics.general]
   // However, if the target is [...]
   // - the object parameter of a user-defined conversion function
@@ -8742,10 +8731,9 @@ void Sema::AddMemberOperatorCandidates(OverloadedOperatorKind Op,
   //        defined, the set of member candidates is the result of the
   //        qualified lookup of T1::operator@ (13.3.1.1.1); otherwise,
   //        the set of member candidates is empty.
-  if (const RecordType *T1Rec = T1->getAs<RecordType>()) {
+  if (T1->isRecordType()) {
     bool IsComplete = isCompleteType(OpLoc, T1);
-    CXXRecordDecl *T1RD =
-        cast<CXXRecordDecl>(T1Rec->getOriginalDecl())->getDefinition();
+    auto *T1RD = T1->getAsCXXRecordDecl();
     // Complete the type if it can be completed.
     // If the type is neither complete nor being defined, bail out now.
     if (!T1RD || (!IsComplete && !T1RD->isBeingDefined()))
@@ -9054,8 +9042,8 @@ BuiltinCandidateTypeSet::AddTypesConvertedFrom(QualType Ty,
   Ty = Ty.getLocalUnqualifiedType();
 
   // Flag if we ever add a non-record type.
-  const RecordType *TyRec = Ty->getAs<RecordType>();
-  HasNonRecordTypes = HasNonRecordTypes || !TyRec;
+  bool TyIsRec = Ty->isRecordType();
+  HasNonRecordTypes = HasNonRecordTypes || !TyIsRec;
 
   // Flag if we encounter an arithmetic type.
   HasArithmeticOrEnumeralTypes =
@@ -9090,13 +9078,12 @@ BuiltinCandidateTypeSet::AddTypesConvertedFrom(QualType Ty,
     MatrixTypes.insert(Ty);
   } else if (Ty->isNullPtrType()) {
     HasNullPtrType = true;
-  } else if (AllowUserConversions && TyRec) {
+  } else if (AllowUserConversions && TyIsRec) {
     // No conversion functions in incomplete types.
     if (!SemaRef.isCompleteType(Loc, Ty))
       return;
 
-    auto *ClassDecl =
-        cast<CXXRecordDecl>(TyRec->getOriginalDecl())->getDefinitionOrSelf();
+    auto *ClassDecl = Ty->castAsCXXRecordDecl();
     for (NamedDecl *D : ClassDecl->getVisibleConversionFunctions()) {
       if (isa<UsingShadowDecl>(D))
         D = cast<UsingShadowDecl>(D)->getTargetDecl();
@@ -10155,7 +10142,7 @@ class BuiltinOperatorOverloadBuilder {
         continue;
       for (QualType MemPtrTy : CandidateTypes[1].member_pointer_types()) {
         const MemberPointerType *mptr = cast<MemberPointerType>(MemPtrTy);
-        CXXRecordDecl *D1 = C1->getAsCXXRecordDecl(),
+        CXXRecordDecl *D1 = C1->castAsCXXRecordDecl(),
                       *D2 = mptr->getMostRecentCXXRecordDecl();
         if (!declaresSameEntity(D1, D2) &&
             !S.IsDerivedFrom(CandidateSet.getLocation(), D1, D2))
@@ -10208,7 +10195,9 @@ class BuiltinOperatorOverloadBuilder {
 
       if (S.getLangOpts().CPlusPlus11) {
         for (QualType EnumTy : CandidateTypes[ArgIdx].enumeration_types()) {
-          if (!EnumTy->castAs<EnumType>()->getOriginalDecl()->isScoped())
+          if (!EnumTy->castAsCanonical<EnumType>()
+                   ->getOriginalDecl()
+                   ->isScoped())
             continue;
 
           if (!AddedTypes.insert(S.Context.getCanonicalType(EnumTy)).second)
@@ -16353,9 +16342,9 @@ Sema::BuildCallToObjectOfClassType(Scope *S, Expr *Obj,
                           diag::err_incomplete_object_call, Object.get()))
     return true;
 
-  const auto *Record = Object.get()->getType()->castAs<RecordType>();
+  auto *Record = Object.get()->getType()->castAsCXXRecordDecl();
   LookupResult R(*this, OpName, LParenLoc, LookupOrdinaryName);
-  LookupQualifiedName(R, Record->getOriginalDecl()->getDefinitionOrSelf());
+  LookupQualifiedName(R, Record);
   R.suppressAccessDiagnostics();
 
   for (LookupResult::iterator Oper = R.begin(), OperEnd = R.end();
@@ -16374,8 +16363,7 @@ Sema::BuildCallToObjectOfClassType(Scope *S, Expr *Obj,
   // we filter them out to produce better error diagnostics, ie to avoid
   // showing 2 failed overloads instead of one.
   bool IgnoreSurrogateFunctions = false;
-  if (CandidateSet.nonDeferredCandidatesCount() == 1 &&
-      Record->getAsCXXRecordDecl()->isLambda()) {
+  if (CandidateSet.nonDeferredCandidatesCount() == 1 && Record->isLambda()) {
     const OverloadCandidate &Candidate = *CandidateSet.begin();
     if (!Candidate.Viable &&
         Candidate.FailureKind == ovl_fail_constraints_not_satisfied)
@@ -16399,9 +16387,7 @@ Sema::BuildCallToObjectOfClassType(Scope *S, Expr *Obj,
   //   functions for each conversion function declared in an
   //   accessible base class provided the function is not hidden
   //   within T by another intervening declaration.
-  const auto &Conversions = cast<CXXRecordDecl>(Record->getOriginalDecl())
-                                ->getDefinitionOrSelf()
-                                ->getVisibleConversionFunctions();
+  const auto &Conversions = Record->getVisibleConversionFunctions();
   for (auto I = Conversions.begin(), E = Conversions.end();
        !IgnoreSurrogateFunctions && I != E; ++I) {
     NamedDecl *D = *I;
@@ -16622,10 +16608,7 @@ ExprResult Sema::BuildOverloadedArrowExpr(Scope *S, Expr *Base,
     return ExprError();
 
   LookupResult R(*this, OpName, OpLoc, LookupOrdinaryName);
-  LookupQualifiedName(R, Base->getType()
-                             ->castAs<RecordType>()
-                             ->getOriginalDecl()
-                             ->getDefinitionOrSelf());
+  LookupQualifiedName(R, Base->getType()->castAsRecordDecl());
   R.suppressAccessDiagnostics();
 
   for (LookupResult::iterator Oper = R.begin(), OperEnd = R.end();
diff --git a/clang/lib/Sema/SemaPPC.cpp b/clang/lib/Sema/SemaPPC.cpp
index 46d7372dd056b..bfa458d207b46 100644
--- a/clang/lib/Sema/SemaPPC.cpp
+++ b/clang/lib/Sema/SemaPPC.cpp
@@ -41,10 +41,7 @@ void SemaPPC::checkAIXMemberAlignment(SourceLocation Loc, const Expr *Arg) {
     return;
 
   QualType ArgType = Arg->getType();
-  for (const FieldDecl *FD : ArgType->castAs<RecordType>()
-                                 ->getOriginalDecl()
-                                 ->getDefinitionOrSelf()
-                                 ->fields()) {
+  for (const FieldDecl *FD : ArgType->castAsRecordDecl()->fields()) {
     if (const auto *AA = FD->getAttr<AlignedAttr>()) {
       CharUnits Alignment = getASTContext().toCharUnitsFromBits(
           AA->getAlignment(getASTContext()));
diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp
index bc1ddb80961a2..5625fb359807a 100644
--- a/clang/lib/Sema/SemaStmt.cpp
+++ b/clang/lib/Sema/SemaStmt.cpp
@@ -1271,8 +1271,8 @@ static void checkEnumTypesInSwitchStmt(Sema &S, const Expr *Cond,
   QualType CondType = Cond->getType();
   QualType CaseType = Case->getType();
 
-  const EnumType *CondEnumType = CondType->getAs<EnumType>();
-  const EnumType *CaseEnumType = CaseType->getAs<EnumType>();
+  const EnumType *CondEnumType = CondType->getAsCanonical<EnumType>();
+  const EnumType *CaseEnumType = CaseType->getAsCanonical<EnumType>();
   if (!CondEnumType || !CaseEnumType)
     return;
 
@@ -1590,12 +1590,10 @@ Sema::ActOnFinishSwitchStmt(SourceLocation SwitchLoc, Stmt *Switch,
     // we still do the analysis to preserve this information in the AST
     // (which can be used by flow-based analyes).
     //
-    const EnumType *ET = CondTypeBeforePromotion->getAs<EnumType>();
-
     // If switch has default case, then ignore it.
     if (!CaseListIsErroneous && !CaseListIsIncomplete && !HasConstantCond &&
-        ET) {
-      const EnumDecl *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+        CondTypeBeforePromotion->isEnumeralType()) {
+      const auto *ED = CondTypeBeforePromotion->castAsEnumDecl();
       if (!ED->isCompleteDefinition() || ED->enumerators().empty())
         goto enum_out;
 
@@ -1730,8 +1728,7 @@ void
 Sema::DiagnoseAssignmentEnum(QualType DstType, QualType SrcType,
                              Expr *SrcExpr) {
 
-  const auto *ET = DstType->getAs<EnumType>();
-  if (!ET)
+  if (!DstType->isEnumeralType())
     return;
 
   if (!SrcType->isIntegerType() ||
@@ -1741,7 +1738,7 @@ Sema::DiagnoseAssignmentEnum(QualType DstType, QualType SrcType,
   if (SrcExpr->isTypeDependent() || SrcExpr->isValueDependent())
     return;
 
-  const EnumDecl *ED = ET->getOriginalDecl()->getDefinitionOrSelf();
+  const auto *ED = DstType->castAsEnumDecl();
   if (!ED->isClosed())
     return;
 
diff --git a/clang/lib/Sema/SemaStmtAsm.cpp b/clang/lib/Sema/SemaStmtAsm.cpp
index cd8b98c7444eb..13211007f358d 100644
--- a/clang/lib/Sema/SemaStmtAsm.cpp
+++ b/clang/lib/Sema/SemaStmtAsm.cpp
@@ -885,18 +885,19 @@ bool Sema::LookupInlineAsmField(StringRef Base, StringRef Member,
   for (StringRef NextMember : Members) {
     const RecordType *RT = nullptr;
     if (VarDecl *VD = dyn_cast<VarDecl>(FoundDecl))
-      RT = VD->getType()->getAs<RecordType>();
+      RT = VD->getType()->getAsCanonical<RecordType>();
     else if (TypedefNameDecl *TD = dyn_cast<TypedefNameDecl>(FoundDecl)) {
       MarkAnyDeclReferenced(TD->getLocation(), TD, /*OdrUse=*/false);
       // MS InlineAsm often uses struct pointer aliases as a base
       QualType QT = TD->getUnderlyingType();
       if (const auto *PT = QT->getAs<PointerType>())
         QT = PT->getPointeeType();
-      RT = QT->getAs<RecordType>();
+      RT = QT->getAsCanonical<RecordType>();
     } else if (TypeDecl *TD = dyn_cast<TypeDecl>(FoundDecl))
-      RT = Context.getTypeDeclType(TD)->getAs<RecordType>();
+      RT = QualType(Context.getCanonicalTypeDeclType(TD))
+               ->getAsCanonical<RecordType>();
     else if (FieldDecl *TD = dyn_cast<FieldDecl>(FoundDecl))
-      RT = TD->getType()->getAs<RecordType>();
+      RT = TD->getType()->getAsCanonical<RecordType>();
     if (!RT)
       return true;
 
@@ -944,16 +945,15 @@ Sema::LookupInlineAsmVarDeclField(Expr *E, StringRef Member,
         /*FirstQualifierFoundInScope=*/nullptr, NameInfo, /*TemplateArgs=*/nullptr);
   }
 
-  const RecordType *RT = T->getAs<RecordType>();
+  const auto *RD = T->getAsRecordDecl();
   // FIXME: Diagnose this as field access into a scalar type.
-  if (!RT)
+  if (!RD)
     return ExprResult();
 
   LookupResult FieldResult(*this, &Context.Idents.get(Member), AsmLoc,
                            LookupMemberName);
 
-  if (!LookupQualifiedName(FieldResult,
-                           RT->getOriginalDecl()->getDefinitionOrSelf()))
+  if (!LookupQualifiedName(FieldResult, RD->getDefinitionOrSelf()))
     return ExprResult();
 
   // Only normal and indirect field results will work.
diff --git a/clang/lib/Sema/SemaSwift.cpp b/clang/lib/Sema/SemaSwift.cpp
index a99222c5ed55f..d21d79344d5c7 100644
--- a/clang/lib/Sema/SemaSwift.cpp
+++ b/clang/lib/Sema/SemaSwift.cpp
@@ -129,9 +129,9 @@ static bool isErrorParameter(Sema &S, QualType QT) {
 
   // Check for CFError**.
   if (const auto *PT = Pointee->getAs<PointerType>())
-    if (const auto *RT = PT->getPointeeType()->getAs<RecordType>())
-      if (S.ObjC().isCFError(RT->getOriginalDecl()->getDefinitionOrSelf()))
-        return true;
+    if (auto *RD = PT->getPointeeType()->getAsRecordDecl();
+        RD && S.ObjC().isCFError(RD))
+      return true;
 
   return false;
 }
@@ -271,12 +271,10 @@ static void checkSwiftAsyncErrorBlock(Sema &S, Decl *D,
       }
       // Check for CFError *.
       if (const auto *PtrTy = Param->getAs<PointerType>()) {
-        if (const auto *RT = PtrTy->getPointeeType()->getAs<RecordType>()) {
-          if (S.ObjC().isCFError(
-                  RT->getOriginalDecl()->getDefinitionOrSelf())) {
-            AnyErrorParams = true;
-            break;
-          }
+        if (auto *RD = PtrTy->getPointeeType()->getAsRecordDecl();
+            RD && S.ObjC().isCFError(RD)) {
+          AnyErrorParams = true;
+          break;
         }
       }
     }
diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index 36bffc5e5e3c9..3dc77fa8ddea1 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -2859,14 +2859,14 @@ TemplateParameterList *Sema::MatchTemplateParametersToScopeSpecifier(
     }
 
     // Retrieve the parent of an enumeration type.
-    if (const EnumType *EnumT = T->getAs<EnumType>()) {
+    if (const EnumType *EnumT = T->getAsCanonical<EnumType>()) {
       // FIXME: Forward-declared enums require a TSK_ExplicitSpecialization
       // check here.
       EnumDecl *Enum = EnumT->getOriginalDecl();
 
       // Get to the parent type.
       if (TypeDecl *Parent = dyn_cast<TypeDecl>(Enum->getParent()))
-        T = Context.getTypeDeclType(Parent);
+        T = Context.getCanonicalTypeDeclType(Parent);
       else
         T = QualType();
       continue;
@@ -3313,7 +3313,7 @@ static bool isInVkNamespace(const RecordType *RT) {
 static SpirvOperand checkHLSLSpirvTypeOperand(Sema &SemaRef,
                                               QualType OperandArg,
                                               SourceLocation Loc) {
-  if (auto *RT = OperandArg->getAs<RecordType>()) {
+  if (auto *RT = OperandArg->getAsCanonical<RecordType>()) {
     bool Literal = false;
     SourceLocation LiteralLoc;
     if (isInVkNamespace(RT) && RT->getOriginalDecl()->getName() == "Literal") {
@@ -3323,7 +3323,7 @@ static SpirvOperand checkHLSLSpirvTypeOperand(Sema &SemaRef,
 
       const TemplateArgumentList &LiteralArgs = SpecDecl->getTemplateArgs();
       QualType ConstantType = LiteralArgs[0].getAsType();
-      RT = ConstantType->getAs<RecordType>();
+      RT = ConstantType->getAsCanonical<RecordType>();
       Literal = true;
       LiteralLoc = SpecDecl->getSourceRange().getBegin();
     }
@@ -4086,7 +4086,8 @@ TypeResult Sema::ActOnTagTemplateIdType(TagUseKind TUK,
     return TypeResult(true);
 
   // Check the tag kind
-  if (const RecordType *RT = Result->getAs<RecordType>()) {
+  // FIXME: This looks like it doesn't need canonical type and definition.
+  if (const RecordType *RT = Result->getAsCanonical<RecordType>()) {
     RecordDecl *D = RT->getOriginalDecl()->getDefinitionOrSelf();
 
     IdentifierInfo *Id = D->getIdentifier();
@@ -4132,7 +4133,7 @@ static bool isTemplateArgumentTemplateParameter(const TemplateArgument &Arg,
   case TemplateArgument::Type: {
     QualType Type = Arg.getAsType();
     const TemplateTypeParmType *TPT =
-        Arg.getAsType()->getAs<TemplateTypeParmType>();
+        Arg.getAsType()->getAsCanonical<TemplateTypeParmType>();
     return TPT && !Type.hasQualifiers() &&
            TPT->getDepth() == Depth && TPT->getIndex() == Index;
   }
@@ -7411,9 +7412,8 @@ ExprResult Sema::CheckTemplateArgument(NamedDecl *Param, QualType ParamType,
       // always a no-op, except when the parameter type is bool. In
       // that case, this may extend the argument from 1 bit to 8 bits.
       QualType IntegerType = ParamType;
-      if (const EnumType *Enum = IntegerType->getAs<EnumType>())
-        IntegerType =
-            Enum->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+      if (const auto *ED = IntegerType->getAsEnumDecl())
+        IntegerType = ED->getIntegerType();
       Value = Value.extOrTrunc(IntegerType->isBitIntType()
                                    ? Context.getIntWidth(IntegerType)
                                    : Context.getTypeSize(IntegerType));
@@ -7510,9 +7510,8 @@ ExprResult Sema::CheckTemplateArgument(NamedDecl *Param, QualType ParamType,
     }
 
     QualType IntegerType = ParamType;
-    if (const EnumType *Enum = IntegerType->getAs<EnumType>()) {
-      IntegerType =
-          Enum->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+    if (const auto *ED = IntegerType->getAsEnumDecl()) {
+      IntegerType = ED->getIntegerType();
     }
 
     if (ParamType->isBooleanType()) {
@@ -8028,8 +8027,8 @@ static Expr *BuildExpressionFromIntegralTemplateArgumentValue(
   // any integral type with C++11 enum classes, make sure we create the right
   // type of literal for it.
   QualType T = OrigT;
-  if (const EnumType *ET = OrigT->getAs<EnumType>())
-    T = ET->getOriginalDecl()->getDefinitionOrSelf()->getIntegerType();
+  if (const auto *ED = OrigT->getAsEnumDecl())
+    T = ED->getIntegerType();
 
   Expr *E;
   if (T->isAnyCharacterType()) {
diff --git a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp
index 93983bf6d1607..cce40c0c91f95 100644
--- a/clang/lib/Sema/SemaTemplateDeduction.cpp
+++ b/clang/lib/Sema/SemaTemplateDeduction.cpp
@@ -698,9 +698,8 @@ DeduceTemplateSpecArguments(Sema &S, TemplateParameterList *TemplateParams,
     TNP = TP->getTemplateName();
     // FIXME: To preserve sugar, the TST needs to carry sugared resolved
     // arguments.
-    PResolved = TP->getCanonicalTypeInternal()
-                    ->castAs<TemplateSpecializationType>()
-                    ->template_arguments();
+    PResolved =
+        TP->castAsCanonical<TemplateSpecializationType>()->template_arguments();
   } else {
     const auto *TT = P->castAs<InjectedClassNameType>();
     TNP = TT->getTemplateName(S.Context);
@@ -1437,7 +1436,8 @@ static bool isForwardingReference(QualType Param, unsigned FirstInnerIndex) {
   if (auto *ParamRef = Param->getAs<RValueReferenceType>()) {
     if (ParamRef->getPointeeType().getQualifiers())
       return false;
-    auto *TypeParm = ParamRef->getPointeeType()->getAs<TemplateTypeParmType>();
+    auto *TypeParm =
+        ParamRef->getPointeeType()->getAsCanonical<TemplateTypeParmType>();
     return TypeParm && TypeParm->getIndex() >= FirstInnerIndex;
   }
   return false;
@@ -1702,7 +1702,7 @@ static TemplateDeductionResult DeduceTemplateArgumentsByTypeMatch(
   //
   //     T
   //     cv-list T
-  if (const auto *TTP = P->getAs<TemplateTypeParmType>()) {
+  if (const auto *TTP = P->getAsCanonical<TemplateTypeParmType>()) {
     // Just skip any attempts to deduce from a placeholder type or a parameter
     // at a different depth.
     if (A->isPlaceholderType() || Info.getDeducedDepth() != TTP->getDepth())
@@ -3559,7 +3559,7 @@ static bool isSimpleTemplateIdType(QualType T) {
   //
   // This only arises during class template argument deduction for a copy
   // deduction candidate, where it permits slicing.
-  if (T->getAs<InjectedClassNameType>())
+  if (isa<InjectedClassNameType>(T.getCanonicalType()))
     return true;
 
   return false;
@@ -5596,7 +5596,7 @@ static TemplateDeductionResult CheckDeductionConsistency(
   // so let it transform their specializations instead.
   bool IsDeductionGuide = isa<CXXDeductionGuideDecl>(FTD->getTemplatedDecl());
   if (IsDeductionGuide) {
-    if (auto *Injected = P->getAs<InjectedClassNameType>())
+    if (auto *Injected = P->getAsCanonical<InjectedClassNameType>())
       P = Injected->getOriginalDecl()->getCanonicalTemplateSpecializationType(
           S.Context);
   }
@@ -5617,10 +5617,10 @@ static TemplateDeductionResult CheckDeductionConsistency(
   auto T1 = S.Context.getUnqualifiedArrayType(InstP.getNonReferenceType());
   auto T2 = S.Context.getUnqualifiedArrayType(A.getNonReferenceType());
   if (IsDeductionGuide) {
-    if (auto *Injected = T1->getAs<InjectedClassNameType>())
+    if (auto *Injected = T1->getAsCanonical<InjectedClassNameType>())
       T1 = Injected->getOriginalDecl()->getCanonicalTemplateSpecializationType(
           S.Context);
-    if (auto *Injected = T2->getAs<InjectedClassNameType>())
+    if (auto *Injected = T2->getAsCanonical<InjectedClassNameType>())
       T2 = Injected->getOriginalDecl()->getCanonicalTemplateSpecializationType(
           S.Context);
   }
diff --git a/clang/lib/Sema/SemaTemplateDeductionGuide.cpp b/clang/lib/Sema/SemaTemplateDeductionGuide.cpp
index 604591408728c..3d54d1eb4373a 100644
--- a/clang/lib/Sema/SemaTemplateDeductionGuide.cpp
+++ b/clang/lib/Sema/SemaTemplateDeductionGuide.cpp
@@ -995,8 +995,8 @@ getRHSTemplateDeclAndArgs(Sema &SemaRef, TypeAliasTemplateDecl *AliasTemplate) {
     // Cases where template arguments in the RHS of the alias are not
     // dependent. e.g.
     //   using AliasFoo = Foo<bool>;
-    if (const auto *CTSD = llvm::dyn_cast<ClassTemplateSpecializationDecl>(
-            RT->getAsCXXRecordDecl())) {
+    if (const auto *CTSD =
+            dyn_cast<ClassTemplateSpecializationDecl>(RT->getOriginalDecl())) {
       Template = CTSD->getSpecializedTemplate();
       AliasRhsTemplateArgs = CTSD->getTemplateArgs().asArray();
     }
@@ -1056,7 +1056,7 @@ BuildDeductionGuideForTypeAlias(Sema &SemaRef,
   // The (trailing) return type of the deduction guide.
   const TemplateSpecializationType *FReturnType =
       RType->getAs<TemplateSpecializationType>();
-  if (const auto *ICNT = RType->getAs<InjectedClassNameType>())
+  if (const auto *ICNT = RType->getAsCanonical<InjectedClassNameType>())
     // implicitly-generated deduction guide.
     FReturnType = cast<TemplateSpecializationType>(
         ICNT->getOriginalDecl()->getCanonicalTemplateSpecializationType(
diff --git a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
index 681ee796440f3..ee1b520fa46e9 100644
--- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -6973,19 +6973,15 @@ NamedDecl *Sema::FindInstantiatedDecl(SourceLocation Loc, NamedDecl *D,
     // If our context used to be dependent, we may need to instantiate
     // it before performing lookup into that context.
     bool IsBeingInstantiated = false;
-    if (CXXRecordDecl *Spec = dyn_cast<CXXRecordDecl>(ParentDC)) {
+    if (auto *Spec = dyn_cast<CXXRecordDecl>(ParentDC)) {
       if (!Spec->isDependentContext()) {
-        CanQualType T = Context.getCanonicalTagType(Spec);
-        const RecordType *Tag = T->getAs<RecordType>();
-        assert(Tag && "type of non-dependent record is not a RecordType");
-        auto *TagDecl =
-            cast<CXXRecordDecl>(Tag->getOriginalDecl())->getDefinitionOrSelf();
-        if (TagDecl->isBeingDefined())
+        if (Spec->isEntityBeingDefined())
           IsBeingInstantiated = true;
-        else if (RequireCompleteType(Loc, T, diag::err_incomplete_type))
+        else if (RequireCompleteType(Loc, Context.getCanonicalTagType(Spec),
+                                     diag::err_incomplete_type))
           return nullptr;
 
-        ParentDC = TagDecl;
+        ParentDC = Spec->getDefinitionOrSelf();
       }
     }
 
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index d745cdbf0526f..3f31a05d382a6 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2115,12 +2115,10 @@ QualType Sema::BuildArrayType(QualType T, ArraySizeModifier ASM,
     return QualType();
   }
 
-  if (const RecordType *EltTy = T->getAs<RecordType>()) {
+  if (const auto *RD = T->getAsRecordDecl()) {
     // If the element type is a struct or union that contains a variadic
     // array, accept it as a GNU extension: C99 6.7.2.1p2.
-    if (EltTy->getOriginalDecl()
-            ->getDefinitionOrSelf()
-            ->hasFlexibleArrayMember())
+    if (RD->hasFlexibleArrayMember())
       Diag(Loc, diag::ext_flexible_array_in_array) << T;
   } else if (T->isObjCObjectType()) {
     Diag(Loc, diag::err_objc_array_of_interfaces) << T;
@@ -3975,10 +3973,7 @@ classifyPointerDeclarator(Sema &S, QualType type, Declarator &declarator,
     if (numNormalPointers == 0)
       return PointerDeclaratorKind::NonPointer;
 
-    if (auto recordType = type->getAs<RecordType>()) {
-      RecordDecl *recordDecl =
-          recordType->getOriginalDecl()->getDefinitionOrSelf();
-
+    if (auto *recordDecl = type->getAsRecordDecl()) {
       // If this is CFErrorRef*, report it as such.
       if (numNormalPointers == 2 && numTypeSpecifierPointers < 2 &&
           S.ObjC().isCFError(recordDecl)) {
@@ -9622,19 +9617,16 @@ bool Sema::RequireLiteralType(SourceLocation Loc, QualType T,
   if (T->isVariableArrayType())
     return true;
 
-  const RecordType *RT = ElemType->getAs<RecordType>();
-  if (!RT)
+  if (!ElemType->isRecordType())
     return true;
 
-  const CXXRecordDecl *RD =
-      cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
-
   // A partially-defined class type can't be a literal type, because a literal
   // class type must have a trivial destructor (which can't be checked until
   // the class definition is complete).
   if (RequireCompleteType(Loc, ElemType, diag::note_non_literal_incomplete, T))
     return true;
 
+  const auto *RD = ElemType->castAsCXXRecordDecl();
   // [expr.prim.lambda]p3:
   //   This class type is [not] a literal type.
   if (RD->isLambda() && !getLangOpts().CPlusPlus17) {
diff --git a/clang/lib/Sema/SemaTypeTraits.cpp b/clang/lib/Sema/SemaTypeTraits.cpp
index 0b4d5916f8dc3..8686bd94f1a2d 100644
--- a/clang/lib/Sema/SemaTypeTraits.cpp
+++ b/clang/lib/Sema/SemaTypeTraits.cpp
@@ -556,13 +556,11 @@ static bool CheckUnaryTypeTraitTypeCompleteness(Sema &S, TypeTrait UTT,
   }
 }
 
-static bool HasNoThrowOperator(const RecordType *RT, OverloadedOperatorKind Op,
+static bool HasNoThrowOperator(CXXRecordDecl *RD, OverloadedOperatorKind Op,
                                Sema &Self, SourceLocation KeyLoc, ASTContext &C,
                                bool (CXXRecordDecl::*HasTrivial)() const,
                                bool (CXXRecordDecl::*HasNonTrivial)() const,
                                bool (CXXMethodDecl::*IsDesiredOp)() const) {
-  CXXRecordDecl *RD =
-      cast<CXXRecordDecl>(RT->getOriginalDecl())->getDefinitionOrSelf();
   if ((RD->*HasTrivial)() && !(RD->*HasNonTrivial)())
     return true;
 
@@ -1007,8 +1005,8 @@ static bool EvaluateUnaryTypeTrait(Sema &Self, TypeTrait UTT,
     if (T.isPODType(C) || T->isObjCLifetimeType())
       return true;
 
-    if (const RecordType *RT = T->getAs<RecordType>())
-      return HasNoThrowOperator(RT, OO_Equal, Self, KeyLoc, C,
+    if (auto *RD = T->getAsCXXRecordDecl())
+      return HasNoThrowOperator(RD, OO_Equal, Self, KeyLoc, C,
                                 &CXXRecordDecl::hasTrivialCopyAssignment,
                                 &CXXRecordDecl::hasNonTrivialCopyAssignment,
                                 &CXXMethodDecl::isCopyAssignmentOperator);
@@ -1020,8 +1018,8 @@ static bool EvaluateUnaryTypeTrait(Sema &Self, TypeTrait UTT,
     if (T.isPODType(C))
       return true;
 
-    if (const RecordType *RT = C.getBaseElementType(T)->getAs<RecordType>())
-      return HasNoThrowOperator(RT, OO_Equal, Self, KeyLoc, C,
+    if (auto *RD = C.getBaseElementType(T)->getAsCXXRecordDecl())
+      return HasNoThrowOperator(RD, OO_Equal, Self, KeyLoc, C,
                                 &CXXRecordDecl::hasTrivialMoveAssignment,
                                 &CXXRecordDecl::hasNonTrivialMoveAssignment,
                                 &CXXMethodDecl::isMoveAssignmentOperator);
@@ -1585,8 +1583,8 @@ bool Sema::BuiltinIsBaseOf(SourceLocation RhsTLoc, QualType LhsT,
   // Base and Derived are not unions and name the same class type without
   // regard to cv-qualifiers.
 
-  const RecordType *lhsRecord = LhsT->getAs<RecordType>();
-  const RecordType *rhsRecord = RhsT->getAs<RecordType>();
+  const RecordType *lhsRecord = LhsT->getAsCanonical<RecordType>();
+  const RecordType *rhsRecord = RhsT->getAsCanonical<RecordType>();
   if (!rhsRecord || !lhsRecord) {
     const ObjCObjectType *LHSObjTy = LhsT->getAs<ObjCObjectType>();
     const ObjCObjectType *RHSObjTy = RhsT->getAs<ObjCObjectType>();
@@ -1645,8 +1643,8 @@ static bool EvaluateBinaryTypeTrait(Sema &Self, TypeTrait BTT,
     return Self.BuiltinIsBaseOf(Rhs->getTypeLoc().getBeginLoc(), LhsT, RhsT);
 
   case BTT_IsVirtualBaseOf: {
-    const RecordType *BaseRecord = LhsT->getAs<RecordType>();
-    const RecordType *DerivedRecord = RhsT->getAs<RecordType>();
+    const RecordType *BaseRecord = LhsT->getAsCanonical<RecordType>();
+    const RecordType *DerivedRecord = RhsT->getAsCanonical<RecordType>();
 
     if (!BaseRecord || !DerivedRecord) {
       DiagnoseVLAInCXXTypeTrait(Self, Lhs,
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 1d14ead778446..885b2aa2c8f94 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -14380,9 +14380,9 @@ TreeTransform<Derived>::TransformCXXTypeidExpr(CXXTypeidExpr *E) {
   Expr *Op = E->getExprOperand();
   auto EvalCtx = Sema::ExpressionEvaluationContext::Unevaluated;
   if (E->isGLValue())
-    if (auto *RecordT = Op->getType()->getAs<RecordType>())
-      if (cast<CXXRecordDecl>(RecordT->getOriginalDecl())->isPolymorphic())
-        EvalCtx = SemaRef.ExprEvalContexts.back().Context;
+    if (auto *RD = Op->getType()->getAsCXXRecordDecl();
+        RD && RD->isPolymorphic())
+      EvalCtx = SemaRef.ExprEvalContexts.back().Context;
 
   EnterExpressionEvaluationContext Unevaluated(SemaRef, EvalCtx,
                                                Sema::ReuseLambdaContextDecl);
@@ -14623,12 +14623,9 @@ TreeTransform<Derived>::TransformCXXNewExpr(CXXNewExpr *E) {
     if (E->isArray() && !E->getAllocatedType()->isDependentType()) {
       QualType ElementType
         = SemaRef.Context.getBaseElementType(E->getAllocatedType());
-      if (const RecordType *RecordT = ElementType->getAs<RecordType>()) {
-        CXXRecordDecl *Record = cast<CXXRecordDecl>(RecordT->getOriginalDecl())
-                                    ->getDefinitionOrSelf();
-        if (CXXDestructorDecl *Destructor = SemaRef.LookupDestructor(Record)) {
+      if (CXXRecordDecl *Record = ElementType->getAsCXXRecordDecl()) {
+        if (CXXDestructorDecl *Destructor = SemaRef.LookupDestructor(Record))
           SemaRef.MarkFunctionReferenced(E->getBeginLoc(), Destructor);
-        }
       }
     }
 
@@ -14694,13 +14691,9 @@ TreeTransform<Derived>::TransformCXXDeleteExpr(CXXDeleteExpr *E) {
     if (!E->getArgument()->isTypeDependent()) {
       QualType Destroyed = SemaRef.Context.getBaseElementType(
                                                          E->getDestroyedType());
-      if (const RecordType *DestroyedRec = Destroyed->getAs<RecordType>()) {
-        CXXRecordDecl *Record =
-            cast<CXXRecordDecl>(DestroyedRec->getOriginalDecl())
-                ->getDefinitionOrSelf();
+      if (auto *Record = Destroyed->getAsCXXRecordDecl())
         SemaRef.MarkFunctionReferenced(E->getBeginLoc(),
                                        SemaRef.LookupDestructor(Record));
-      }
     }
 
     return E;
@@ -17650,12 +17643,13 @@ TreeTransform<Derived>::RebuildCXXPseudoDestructorExpr(Expr *Base,
                                                        SourceLocation CCLoc,
                                                        SourceLocation TildeLoc,
                                         PseudoDestructorTypeStorage Destroyed) {
-  QualType BaseType = Base->getType();
+  QualType CanonicalBaseType = Base->getType().getCanonicalType();
   if (Base->isTypeDependent() || Destroyed.getIdentifier() ||
-      (!isArrow && !BaseType->getAs<RecordType>()) ||
-      (isArrow && BaseType->getAs<PointerType>() &&
-       !BaseType->castAs<PointerType>()->getPointeeType()
-                                              ->template getAs<RecordType>())){
+      (!isArrow && !isa<RecordType>(CanonicalBaseType)) ||
+      (isArrow && isa<PointerType>(CanonicalBaseType) &&
+       !cast<PointerType>(CanonicalBaseType)
+            ->getPointeeType()
+            ->getAsCanonical<RecordType>())) {
     // This pseudo-destructor expression is still a pseudo-destructor.
     return SemaRef.BuildPseudoDestructorExpr(
         Base, OperatorLoc, isArrow ? tok::arrow : tok::period, SS, ScopeType,
@@ -17682,13 +17676,11 @@ TreeTransform<Derived>::RebuildCXXPseudoDestructorExpr(Expr *Base,
   }
 
   SourceLocation TemplateKWLoc; // FIXME: retrieve it from caller.
-  return getSema().BuildMemberReferenceExpr(Base, BaseType,
-                                            OperatorLoc, isArrow,
-                                            SS, TemplateKWLoc,
-                                            /*FIXME: FirstQualifier*/ nullptr,
-                                            NameInfo,
-                                            /*TemplateArgs*/ nullptr,
-                                            /*S*/nullptr);
+  return getSema().BuildMemberReferenceExpr(
+      Base, Base->getType(), OperatorLoc, isArrow, SS, TemplateKWLoc,
+      /*FIXME: FirstQualifier*/ nullptr, NameInfo,
+      /*TemplateArgs*/ nullptr,
+      /*S*/ nullptr);
 }
 
 template<typename Derived>
diff --git a/clang/lib/Sema/UsedDeclVisitor.h b/clang/lib/Sema/UsedDeclVisitor.h
index ad475ab0f42ae..546b3a9e66e21 100644
--- a/clang/lib/Sema/UsedDeclVisitor.h
+++ b/clang/lib/Sema/UsedDeclVisitor.h
@@ -70,12 +70,10 @@ class UsedDeclVisitor : public EvaluatedExprVisitor<Derived> {
     QualType DestroyedOrNull = E->getDestroyedType();
     if (!DestroyedOrNull.isNull()) {
       QualType Destroyed = S.Context.getBaseElementType(DestroyedOrNull);
-      if (const RecordType *DestroyedRec = Destroyed->getAs<RecordType>()) {
-        CXXRecordDecl *Record =
-            cast<CXXRecordDecl>(DestroyedRec->getOriginalDecl());
-        if (auto *Def = Record->getDefinition())
-          asImpl().visitUsedDecl(E->getBeginLoc(), S.LookupDestructor(Def));
-      }
+      if (auto *Record = Destroyed->getAsCXXRecordDecl();
+          Record &&
+          (Record->isBeingDefined() || Record->isCompleteDefinition()))
+        asImpl().visitUsedDecl(E->getBeginLoc(), S.LookupDestructor(Record));
     }
 
     Inherited::VisitCXXDeleteExpr(E);
diff --git a/clang/lib/StaticAnalyzer/Checkers/CastSizeChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/CastSizeChecker.cpp
index 90c6537d71d9d..781216da77f73 100644
--- a/clang/lib/StaticAnalyzer/Checkers/CastSizeChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/CastSizeChecker.cpp
@@ -49,11 +49,10 @@ class CastSizeChecker : public Checker< check::PreStmt<CastExpr> > {
 /// of struct bar.
 static bool evenFlexibleArraySize(ASTContext &Ctx, CharUnits RegionSize,
                                   CharUnits TypeSize, QualType ToPointeeTy) {
-  const RecordType *RT = ToPointeeTy->getAs<RecordType>();
-  if (!RT)
+  const auto *RD = ToPointeeTy->getAsRecordDecl();
+  if (!RD)
     return false;
 
-  const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
   RecordDecl::field_iterator Iter(RD->field_begin());
   RecordDecl::field_iterator End(RD->field_end());
   const FieldDecl *Last = nullptr;
diff --git a/clang/lib/StaticAnalyzer/Checkers/EnumCastOutOfRangeChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/EnumCastOutOfRangeChecker.cpp
index 054b2e96bd13b..76a1470aaac44 100644
--- a/clang/lib/StaticAnalyzer/Checkers/EnumCastOutOfRangeChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/EnumCastOutOfRangeChecker.cpp
@@ -139,18 +139,11 @@ void EnumCastOutOfRangeChecker::checkPreStmt(const CastExpr *CE,
   if (!ValueToCast)
     return;
 
-  const QualType T = CE->getType();
   // Check whether the cast type is an enum.
-  if (!T->isEnumeralType())
+  const auto *ED = CE->getType()->getAsEnumDecl();
+  if (!ED)
     return;
 
-  // If the cast is an enum, get its declaration.
-  // If the isEnumeralType() returned true, then the declaration must exist
-  // even if it is a stub declaration. It is up to the getDeclValuesForEnum()
-  // function to handle this.
-  const EnumDecl *ED =
-      T->castAs<EnumType>()->getOriginalDecl()->getDefinitionOrSelf();
-
   // [[clang::flag_enum]] annotated enums are by definition should be ignored.
   if (ED->hasAttr<FlagEnumAttr>())
     return;
diff --git a/clang/lib/StaticAnalyzer/Checkers/LLVMConventionsChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/LLVMConventionsChecker.cpp
index 828b6f91d81c2..c0727ae5dc8ba 100644
--- a/clang/lib/StaticAnalyzer/Checkers/LLVMConventionsChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/LLVMConventionsChecker.cpp
@@ -27,7 +27,7 @@ using namespace ento;
 //===----------------------------------------------------------------------===//
 
 static bool IsLLVMStringRef(QualType T) {
-  const RecordType *RT = T->getAs<RecordType>();
+  const RecordType *RT = T->getAsCanonical<RecordType>();
   if (!RT)
     return false;
 
@@ -195,14 +195,10 @@ static bool IsPartOfAST(const CXXRecordDecl *R) {
   if (IsClangStmt(R) || IsClangType(R) || IsClangDecl(R) || IsClangAttr(R))
     return true;
 
-  for (const auto &BS : R->bases()) {
-    QualType T = BS.getType();
-    if (const RecordType *baseT = T->getAs<RecordType>()) {
-      CXXRecordDecl *baseD = cast<CXXRecordDecl>(baseT->getOriginalDecl());
-      if (IsPartOfAST(baseD))
-        return true;
-    }
-  }
+  for (const auto &BS : R->bases())
+    if (const auto *baseD = BS.getType()->getAsCXXRecordDecl();
+        baseD && IsPartOfAST(baseD))
+      return true;
 
   return false;
 }
@@ -243,11 +239,9 @@ void ASTFieldVisitor::Visit(FieldDecl *D) {
   if (AllocatesMemory(T))
     ReportError(T);
 
-  if (const RecordType *RT = T->getAs<RecordType>()) {
-    const RecordDecl *RD = RT->getOriginalDecl()->getDefinition();
+  if (const auto *RD = T->getAsRecordDecl())
     for (auto *I : RD->fields())
       Visit(I);
-  }
 
   FieldChain.pop_back();
 }
diff --git a/clang/lib/StaticAnalyzer/Checkers/PaddingChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/PaddingChecker.cpp
index 7ef659518ab1b..1554604b374ca 100644
--- a/clang/lib/StaticAnalyzer/Checkers/PaddingChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/PaddingChecker.cpp
@@ -118,12 +118,12 @@ class PaddingChecker : public Checker<check::ASTDecl<TranslationUnitDecl>> {
       Elts = CArrTy->getZExtSize();
     if (Elts == 0)
       return;
-    const RecordType *RT = ArrTy->getElementType()->getAs<RecordType>();
-    if (RT == nullptr)
+    const auto *RD = ArrTy->getElementType()->getAsRecordDecl();
+    if (!RD)
       return;
 
     // TODO: Recurse into the fields to see if they have excess padding.
-    visitRecord(RT->getOriginalDecl()->getDefinitionOrSelf(), Elts);
+    visitRecord(RD, Elts);
   }
 
   bool shouldSkipDecl(const RecordDecl *RD) const {
diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
index 36c12582a5787..884dbe90e7b12 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
@@ -241,7 +241,7 @@ void RetainTypeChecker::visitTypedef(const TypedefDecl *TD) {
     return;
 
   auto PointeeQT = QT->getPointeeType();
-  const RecordType *RT = PointeeQT->getAs<RecordType>();
+  const RecordType *RT = PointeeQT->getAsCanonical<RecordType>();
   if (!RT) {
     if (TD->hasAttr<ObjCBridgeAttr>() || TD->hasAttr<ObjCBridgeMutableAttr>()) {
       RecordlessTypes.insert(TD->getASTContext()
diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
index c853c00019c10..785cdfa15bf04 100644
--- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
+++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp
@@ -3165,7 +3165,7 @@ void ExprEngine::processSwitch(SwitchNodeBuilder& builder) {
   // feasible then it shouldn't be considered for making 'default:' reachable.
   const SwitchStmt *SS = builder.getSwitch();
   const Expr *CondExpr = SS->getCond()->IgnoreParenImpCasts();
-  if (CondExpr->getType()->getAs<EnumType>()) {
+  if (CondExpr->getType()->isEnumeralType()) {
     if (SS->isAllEnumCasesCovered())
       return;
   }
diff --git a/clang/lib/StaticAnalyzer/Core/RegionStore.cpp b/clang/lib/StaticAnalyzer/Core/RegionStore.cpp
index 02375b0c3469a..8f18533af68b9 100644
--- a/clang/lib/StaticAnalyzer/Core/RegionStore.cpp
+++ b/clang/lib/StaticAnalyzer/Core/RegionStore.cpp
@@ -2454,7 +2454,7 @@ NonLoc RegionStoreManager::createLazyBinding(RegionBindingsConstRef B,
 SVal RegionStoreManager::getBindingForStruct(RegionBindingsConstRef B,
                                              const TypedValueRegion *R) {
   const RecordDecl *RD =
-      R->getValueType()->castAs<RecordType>()->getOriginalDecl();
+      R->getValueType()->castAsCanonical<RecordType>()->getOriginalDecl();
   if (!RD->getDefinition())
     return UnknownVal();
 
@@ -2844,9 +2844,7 @@ RegionStoreManager::bindStruct(LimitedRegionBindingsConstRef B,
   QualType T = R->getValueType();
   assert(T->isStructureOrClassType());
 
-  const RecordType* RT = T->castAs<RecordType>();
-  const RecordDecl *RD = RT->getOriginalDecl()->getDefinitionOrSelf();
-
+  const auto *RD = T->castAsRecordDecl();
   if (!RD->isCompleteDefinition())
     return B;
 
diff --git a/clang/test/CXX/drs/cwg0xx.cpp b/clang/test/CXX/drs/cwg0xx.cpp
index 9aeee7a76903d..805be67f2dc1a 100644
--- a/clang/test/CXX/drs/cwg0xx.cpp
+++ b/clang/test/CXX/drs/cwg0xx.cpp
@@ -244,7 +244,7 @@ namespace cwg16 { // cwg16: 2.8
       // expected-error@#cwg16-A-f-call {{'A' is a private member of 'cwg16::A'}}
       //   expected-note@#cwg16-B {{constrained by implicitly private inheritance here}}
       //   expected-note@#cwg16-A {{member is declared here}}
-      // expected-error@#cwg16-A-f-call {{cannot cast 'cwg16::C' to its private base class 'cwg16::A'}}
+      // expected-error@#cwg16-A-f-call {{cannot cast 'cwg16::C' to its private base class 'A'}}
       //   expected-note@#cwg16-B {{implicitly declared private here}}
     }
   };
@@ -838,7 +838,7 @@ namespace cwg52 { // cwg52: 2.8
   // expected-error@#cwg52-k {{'A' is a private member of 'cwg52::A'}}
   //   expected-note@#cwg52-B {{constrained by private inheritance here}}
   //   expected-note@#cwg52-A {{member is declared here}}
-  // expected-error@#cwg52-k {{cannot cast 'struct B' to its private base class 'cwg52::A'}}
+  // expected-error@#cwg52-k {{cannot cast 'struct B' to its private base class 'A'}}
   //   expected-note@#cwg52-B {{declared private here}}
 } // namespace cwg52
 
diff --git a/clang/test/CXX/drs/cwg3xx.cpp b/clang/test/CXX/drs/cwg3xx.cpp
index 5d09697a3cf6b..bbd87c060801a 100644
--- a/clang/test/CXX/drs/cwg3xx.cpp
+++ b/clang/test/CXX/drs/cwg3xx.cpp
@@ -1321,7 +1321,7 @@ namespace cwg381 { // cwg381: 2.7
   void f() {
     E e;
     e.B::a = 0;
-    /* expected-error@-1 {{ambiguous conversion from derived class 'E' to base class 'cwg381::B':
+    /* expected-error@-1 {{ambiguous conversion from derived class 'E' to base class 'B':
     struct cwg381::E -> C -> B
     struct cwg381::E -> D -> B}} */
     F f;
diff --git a/clang/test/ExtractAPI/class_template_param_inheritance.cpp b/clang/test/ExtractAPI/class_template_param_inheritance.cpp
index 53b331e0b460b..5f0056456c730 100644
--- a/clang/test/ExtractAPI/class_template_param_inheritance.cpp
+++ b/clang/test/ExtractAPI/class_template_param_inheritance.cpp
@@ -44,7 +44,7 @@ template<typename T> class Foo : public T {};
     {
       "kind": "inheritsFrom",
       "source": "c:@ST>1#T@Foo",
-      "target": "",
+      "target": "c:input.h@9",
       "targetFallback": "T"
     }
   ],
diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
index e1aa9bb0721f8..858423a06576a 100644
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -1728,7 +1728,7 @@ bool CursorVisitor::VisitUsingTypeLoc(UsingTypeLoc TL) {
   if (VisitNestedNameSpecifierLoc(TL.getQualifierLoc()))
     return true;
 
-  auto *underlyingDecl = TL.getTypePtr()->desugar()->getAsTagDecl();
+  auto *underlyingDecl = TL.getTypePtr()->getAsTagDecl();
   if (underlyingDecl) {
     return Visit(MakeCursorTypeRef(underlyingDecl, TL.getNameLoc(), TU));
   }
diff --git a/clang/unittests/AST/ASTImporterTest.cpp b/clang/unittests/AST/ASTImporterTest.cpp
index 15df9af889386..5badbd7d65e48 100644
--- a/clang/unittests/AST/ASTImporterTest.cpp
+++ b/clang/unittests/AST/ASTImporterTest.cpp
@@ -9189,7 +9189,7 @@ struct ImportInjectedClassNameType : public ASTImporterOptionSpecificTestBase {
                          .getTypePtr();
     ASSERT_TRUE(Ty);
     EXPECT_FALSE(Ty->isCanonicalUnqualified());
-    const auto *InjTy = Ty->castAs<InjectedClassNameType>();
+    const auto *InjTy = dyn_cast<InjectedClassNameType>(Ty);
     EXPECT_TRUE(InjTy);
     for (const Decl *ReD : D->redecls()) {
       if (ReD == D)
@@ -9204,8 +9204,6 @@ struct ImportInjectedClassNameType : public ASTImporterOptionSpecificTestBase {
       EXPECT_FALSE(ReTy->isCanonicalUnqualified());
       EXPECT_NE(ReTy, Ty);
       EXPECT_TRUE(Ctx.hasSameType(ReTy, Ty));
-      const auto *ReInjTy = Ty->castAs<InjectedClassNameType>();
-      EXPECT_TRUE(ReInjTy);
     }
   }
 
diff --git a/clang/unittests/AST/RandstructTest.cpp b/clang/unittests/AST/RandstructTest.cpp
index cba446fc529e1..a90665b2c0631 100644
--- a/clang/unittests/AST/RandstructTest.cpp
+++ b/clang/unittests/AST/RandstructTest.cpp
@@ -531,8 +531,7 @@ TEST(RANDSTRUCT_TEST, AnonymousStructsAndUnionsRetainFieldOrder) {
 
   for (const Decl *D : RD->decls())
     if (const FieldDecl *FD = dyn_cast<FieldDecl>(D)) {
-      if (const auto *Record = FD->getType()->getAs<RecordType>()) {
-        RD = Record->getOriginalDecl()->getDefinitionOrSelf();
+      if (const auto *RD = FD->getType()->getAsRecordDecl()) {
         if (RD->isAnonymousStructOrUnion()) {
           // These field orders shouldn't change.
           if (RD->isUnion()) {
diff --git a/clang/utils/TableGen/ASTTableGen.h b/clang/utils/TableGen/ASTTableGen.h
index 02b97636cf5f2..e9a86f3d5edd3 100644
--- a/clang/utils/TableGen/ASTTableGen.h
+++ b/clang/utils/TableGen/ASTTableGen.h
@@ -38,7 +38,7 @@
 #define AlwaysDependentClassName "AlwaysDependent"
 #define NeverCanonicalClassName "NeverCanonical"
 #define NeverCanonicalUnlessDependentClassName "NeverCanonicalUnlessDependent"
-#define LeafTypeClassName "LeafType"
+#define AlwaysCanonicalTypeClassName "AlwaysCanonical"
 
 // Cases of various non-ASTNode structured types like DeclarationName.
 #define TypeKindClassName "PropertyTypeKind"
diff --git a/clang/utils/TableGen/ClangTypeNodesEmitter.cpp b/clang/utils/TableGen/ClangTypeNodesEmitter.cpp
index 37039361cfc22..2e1eaef6c1823 100644
--- a/clang/utils/TableGen/ClangTypeNodesEmitter.cpp
+++ b/clang/utils/TableGen/ClangTypeNodesEmitter.cpp
@@ -40,8 +40,9 @@
 // There is a sixth macro, independent of the others.  Most clients
 // will not need to use it.
 //
-//    LEAF_TYPE(Class) - A type that never has inner types.  Clients
-//    which can operate on such types more efficiently may wish to do so.
+//    ALWAYS_CANONICAL_TYPE(Class) - A type which is always identical to its
+//    canonical type.  Clients which can operate on such types more efficiently
+//    may wish to do so.
 //
 //===----------------------------------------------------------------------===//
 
@@ -66,7 +67,7 @@ using namespace clang::tblgen;
 #define NonCanonicalUnlessDependentTypeMacroName "NON_CANONICAL_UNLESS_DEPENDENT_TYPE"
 #define TypeMacroArgs "(Class, Base)"
 #define LastTypeMacroName "LAST_TYPE"
-#define LeafTypeMacroName "LEAF_TYPE"
+#define AlwaysCanonicalTypeMacroName "ALWAYS_CANONICAL_TYPE"
 
 #define TypeClassName "Type"
 
@@ -90,7 +91,7 @@ class TypeNodeEmitter {
 
   void emitNodeInvocations();
   void emitLastNodeInvocation(TypeNode lastType);
-  void emitLeafNodeInvocations();
+  void emitAlwaysCanonicalNodeInvocations();
 
   void addMacroToUndef(StringRef macroName);
   void emitUndefs();
@@ -109,12 +110,12 @@ void TypeNodeEmitter::emit() {
   emitFallbackDefine(AbstractTypeMacroName, TypeMacroName, TypeMacroArgs);
   emitFallbackDefine(NonCanonicalTypeMacroName, TypeMacroName, TypeMacroArgs);
   emitFallbackDefine(DependentTypeMacroName, TypeMacroName, TypeMacroArgs);
-  emitFallbackDefine(NonCanonicalUnlessDependentTypeMacroName, TypeMacroName, 
+  emitFallbackDefine(NonCanonicalUnlessDependentTypeMacroName, TypeMacroName,
                      TypeMacroArgs);
 
   // Invocations.
   emitNodeInvocations();
-  emitLeafNodeInvocations();
+  emitAlwaysCanonicalNodeInvocations();
 
   // Postmatter
   emitUndefs();
@@ -178,15 +179,16 @@ void TypeNodeEmitter::emitLastNodeInvocation(TypeNode type) {
          "#endif\n";
 }
 
-void TypeNodeEmitter::emitLeafNodeInvocations() {
-  Out << "#ifdef " LeafTypeMacroName "\n";
+void TypeNodeEmitter::emitAlwaysCanonicalNodeInvocations() {
+  Out << "#ifdef " AlwaysCanonicalTypeMacroName "\n";
 
   for (TypeNode type : Types) {
-    if (!type.isSubClassOf(LeafTypeClassName)) continue;
-    Out << LeafTypeMacroName "(" << type.getId() << ")\n";
+    if (!type.isSubClassOf(AlwaysCanonicalTypeClassName))
+      continue;
+    Out << AlwaysCanonicalTypeMacroName "(" << type.getId() << ")\n";
   }
 
-  Out << "#undef " LeafTypeMacroName "\n"
+  Out << "#undef " AlwaysCanonicalTypeMacroName "\n"
          "#endif\n";
 }
 

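For readers skimming the hunks above, the recurring substitution (getAs<T>
replaced by getAsCanonical<T> or by a getAs*Decl helper) can be illustrated
with a toy model. This is only a sketch under stated assumptions: the Type
struct, the Kind enum, the *Sketch functions and the three sample nodes are
hypothetical and are not Clang's classes or API. It shows why the two casts
can return different nodes for the same type.

```cpp
#include <cassert>

// Toy model only. None of these names exist in Clang.
enum class Kind { Record, Typedef };

struct Type {
  Kind K;
  const Type *Desugared; // one desugaring step; null when the node is not sugar
  const Type *Canonical; // the canonical node for this type
  const char *Spelling;  // how the type was written in source
};

// Like getAs: walk the top-level sugar and return the first node of the
// requested kind, so the spelling kept on that node is preserved.
inline const Type *getAsSketch(const Type *T, Kind Wanted) {
  for (const Type *Cur = T; Cur; Cur = Cur->Desugared)
    if (Cur->K == Wanted)
      return Cur;
  return nullptr;
}

// Like getAsCanonical: jump straight to the canonical node and test its
// kind. No traversal, but whatever spelling the sugar carried is lost.
inline const Type *getAsCanonicalSketch(const Type *T, Kind Wanted) {
  const Type *C = T->Canonical;
  return C->K == Wanted ? C : nullptr;
}

// Like castAsCanonical: same as above, but the caller promises the kind.
inline const Type *castAsCanonicalSketch(const Type *T, Kind Wanted) {
  const Type *C = getAsCanonicalSketch(T, Wanted);
  assert(C && "canonical type does not have the requested kind");
  return C;
}

// Three nodes for one type: the canonical record, a record node that
// remembers how it was written, and a typedef naming the written record.
const Type CanonRec = {Kind::Record, nullptr, &CanonRec, "S"};
const Type WrittenRec = {Kind::Record, nullptr, &CanonRec, "struct ns::S"};
const Type Alias = {Kind::Typedef, &WrittenRec, &CanonRec, "Alias"};
```

In this model getAsSketch(&Alias, Kind::Record) yields WrittenRec while
getAsCanonicalSketch yields CanonRec. Callers that only need to reach the
declaration, which is the case for most of the hunks above, lose nothing by
taking the canonical route.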

