[clang] [clang][Interp] Handle virtual calls with covariant return types (PR #101218)

via cfe-commits cfe-commits at lists.llvm.org
Tue Jul 30 11:37:20 PDT 2024


Timm Bäder <tbaeder at redhat.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/101218 at github.com>


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Timm Baeder (tbaederr)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/101218.diff


14 Files Affected:

- (modified) clang/include/clang/AST/Type.h (+5) 
- (modified) clang/include/clang/Sema/Overload.h (+1-1) 
- (modified) clang/lib/AST/ExprCXX.cpp (+1-1) 
- (modified) clang/lib/AST/Interp/Context.cpp (+1-1) 
- (modified) clang/lib/AST/Interp/Interp.h (+23-1) 
- (modified) clang/lib/AST/ItaniumMangle.cpp (+1-1) 
- (modified) clang/lib/AST/Type.cpp (+1-1) 
- (modified) clang/lib/Analysis/Consumed.cpp (+3-7) 
- (modified) clang/lib/Analysis/ThreadSafetyCommon.cpp (+1-1) 
- (modified) clang/lib/Sema/SemaDecl.cpp (+3-3) 
- (modified) clang/lib/Sema/SemaDeclAttr.cpp (+1-1) 
- (modified) clang/lib/Sema/SemaExprCXX.cpp (+1-1) 
- (modified) clang/lib/Sema/SemaTemplate.cpp (+1-1) 
- (modified) clang/test/AST/Interp/cxx2a.cpp (+21) 


``````````diff
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 89a74ff1fb285..dec51e032158e 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2509,6 +2509,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   bool isFunctionNoProtoType() const { return getAs<FunctionNoProtoType>(); }
   bool isFunctionProtoType() const { return getAs<FunctionProtoType>(); }
   bool isPointerType() const;
+  bool isPointerOrReferenceType() const;
   bool isSignableType() const;
   bool isAnyPointerType() const;   // Any C pointer or ObjC object pointer
   bool isCountAttributedType() const;
@@ -7996,6 +7997,10 @@ inline bool Type::isPointerType() const {
   return isa<PointerType>(CanonicalType);
 }
 
+inline bool Type::isPointerOrReferenceType() const {
+  return isPointerType() || isReferenceType();
+}
+
 inline bool Type::isAnyPointerType() const {
   return isPointerType() || isObjCObjectPointerType();
 }
diff --git a/clang/include/clang/Sema/Overload.h b/clang/include/clang/Sema/Overload.h
index 26ffe057c74a2..d6a6cee62a752 100644
--- a/clang/include/clang/Sema/Overload.h
+++ b/clang/include/clang/Sema/Overload.h
@@ -984,7 +984,7 @@ class Sema;
     unsigned getNumParams() const {
       if (IsSurrogate) {
         QualType STy = Surrogate->getConversionType();
-        while (STy->isPointerType() || STy->isReferenceType())
+        while (STy->isPointerOrReferenceType())
           STy = STy->getPointeeType();
         return STy->castAs<FunctionProtoType>()->getNumParams();
       }
diff --git a/clang/lib/AST/ExprCXX.cpp b/clang/lib/AST/ExprCXX.cpp
index e2c9643151126..6212989e21737 100644
--- a/clang/lib/AST/ExprCXX.cpp
+++ b/clang/lib/AST/ExprCXX.cpp
@@ -152,7 +152,7 @@ bool CXXTypeidExpr::isMostDerived(ASTContext &Context) const {
   const Expr *E = getExprOperand()->IgnoreParenNoopCasts(Context);
   if (const auto *DRE = dyn_cast<DeclRefExpr>(E)) {
     QualType Ty = DRE->getDecl()->getType();
-    if (!Ty->isPointerType() && !Ty->isReferenceType())
+    if (!Ty->isPointerOrReferenceType())
       return true;
   }
 
diff --git a/clang/lib/AST/Interp/Context.cpp b/clang/lib/AST/Interp/Context.cpp
index b5e992c5a9ac1..fd66e6fe8dcbc 100644
--- a/clang/lib/AST/Interp/Context.cpp
+++ b/clang/lib/AST/Interp/Context.cpp
@@ -176,7 +176,7 @@ std::optional<PrimType> Context::classify(QualType T) const {
       T->isFunctionType())
     return PT_FnPtr;
 
-  if (T->isReferenceType() || T->isPointerType() ||
+  if (T->isPointerOrReferenceType() ||
       T->isObjCObjectPointerType())
     return PT_Ptr;
 
diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
index d128988a480e1..63e9966b831db 100644
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -2628,7 +2628,29 @@ inline bool CallVirt(InterpState &S, CodePtr OpPC, const Function *Func,
     }
   }
 
-  return Call(S, OpPC, Func, VarArgSize);
+  if (!Call(S, OpPC, Func, VarArgSize))
+    return false;
+
+  // Covariant return types. The return type of Overrider is a pointer
+  // or reference to a class type.
+  if (Overrider != InitialFunction &&
+      Overrider->getReturnType()->isPointerOrReferenceType() &&
+      InitialFunction->getReturnType()->isPointerOrReferenceType()) {
+    QualType OverriderPointeeType =
+        Overrider->getReturnType()->getPointeeType();
+    QualType InitialPointeeType =
+        InitialFunction->getReturnType()->getPointeeType();
+    // We've called Overrider above, but calling code expects us to return what
+    // InitialFunction returned. According to the rules for covariant return
+    // types, what InitialFunction returns needs to be a base class of what
+    // Overrider returns. So, we need to do an upcast here.
+    unsigned Offset = S.getContext().collectBaseOffset(
+        InitialPointeeType->getAsRecordDecl(),
+        OverriderPointeeType->getAsRecordDecl());
+    return GetPtrBasePop(S, OpPC, Offset);
+  }
+
+  return true;
 }
 
 inline bool CallBI(InterpState &S, CodePtr &PC, const Function *Func,
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index d46d621d4c7d4..ead5da4e90f2f 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -6484,7 +6484,7 @@ void CXXNameMangler::mangleValueInTemplateArg(QualType T, const APValue &V,
 
   case APValue::LValue: {
     // Proposed in https://github.com/itanium-cxx-abi/cxx-abi/issues/47.
-    assert((T->isPointerType() || T->isReferenceType()) &&
+    assert((T->isPointerOrReferenceType()) &&
            "unexpected type for LValue template arg");
 
     if (V.isNullPointer()) {
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index fdaab8e434593..0456b5f96b210 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -75,7 +75,7 @@ bool Qualifiers::isStrictSupersetOf(Qualifiers Other) const {
 const IdentifierInfo* QualType::getBaseTypeIdentifier() const {
   const Type* ty = getTypePtr();
   NamedDecl *ND = nullptr;
-  if (ty->isPointerType() || ty->isReferenceType())
+  if (ty->isPointerOrReferenceType())
     return ty->getPointeeType().getBaseTypeIdentifier();
   else if (ty->isRecordType())
     ND = ty->castAs<RecordType>()->getDecl();
diff --git a/clang/lib/Analysis/Consumed.cpp b/clang/lib/Analysis/Consumed.cpp
index d01c7f688e8b5..63c5943242944 100644
--- a/clang/lib/Analysis/Consumed.cpp
+++ b/clang/lib/Analysis/Consumed.cpp
@@ -141,7 +141,7 @@ static bool isCallableInState(const CallableWhenAttr *CWAttr,
 }
 
 static bool isConsumableType(const QualType &QT) {
-  if (QT->isPointerType() || QT->isReferenceType())
+  if (QT->isPointerOrReferenceType())
     return false;
 
   if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
@@ -151,7 +151,7 @@ static bool isConsumableType(const QualType &QT) {
 }
 
 static bool isAutoCastType(const QualType &QT) {
-  if (QT->isPointerType() || QT->isReferenceType())
+  if (QT->isPointerOrReferenceType())
     return false;
 
   if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())
@@ -186,10 +186,6 @@ static bool isTestingFunction(const FunctionDecl *FunDecl) {
   return FunDecl->hasAttr<TestTypestateAttr>();
 }
 
-static bool isPointerOrRef(QualType ParamType) {
-  return ParamType->isPointerType() || ParamType->isReferenceType();
-}
-
 static ConsumedState mapConsumableAttrState(const QualType QT) {
   assert(isConsumableType(QT));
 
@@ -648,7 +644,7 @@ bool ConsumedStmtVisitor::handleCall(const CallExpr *Call, const Expr *ObjArg,
       setStateForVarOrTmp(StateMap, PInfo, mapReturnTypestateAttrState(RT));
     else if (isRValueRef(ParamType) || isConsumableType(ParamType))
       setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Consumed);
-    else if (isPointerOrRef(ParamType) &&
+    else if (ParamType->isPointerOrReferenceType() &&
              (!ParamType->getPointeeType().isConstQualified() ||
               isSetOnReadPtrType(ParamType)))
       setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Unknown);
diff --git a/clang/lib/Analysis/ThreadSafetyCommon.cpp b/clang/lib/Analysis/ThreadSafetyCommon.cpp
index 3e8c959ccee4f..cbcfefdc52549 100644
--- a/clang/lib/Analysis/ThreadSafetyCommon.cpp
+++ b/clang/lib/Analysis/ThreadSafetyCommon.cpp
@@ -97,7 +97,7 @@ static StringRef ClassifyDiagnostic(QualType VDT) {
     if (const auto *TD = TT->getDecl())
       if (const auto *CA = TD->getAttr<CapabilityAttr>())
         return ClassifyDiagnostic(CA);
-  } else if (VDT->isPointerType() || VDT->isReferenceType())
+  } else if (VDT->isPointerOrReferenceType())
     return ClassifyDiagnostic(VDT->getPointeeType());
 
   return "mutex";
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 694a754646f27..cac68dc1b22a1 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -5887,7 +5887,7 @@ Sema::GetNameFromUnqualifiedId(const UnqualifiedId &Name) {
 
 static QualType getCoreType(QualType Ty) {
   do {
-    if (Ty->isPointerType() || Ty->isReferenceType())
+    if (Ty->isPointerOrReferenceType())
       Ty = Ty->getPointeeType();
     else if (Ty->isArrayType())
       Ty = Ty->castAsArrayTypeUnsafe()->getElementType();
@@ -9339,7 +9339,7 @@ static OpenCLParamType getOpenCLKernelParameterType(Sema &S, QualType PT) {
   if (PT->isDependentType())
     return InvalidKernelParam;
 
-  if (PT->isPointerType() || PT->isReferenceType()) {
+  if (PT->isPointerOrReferenceType()) {
     QualType PointeeType = PT->getPointeeType();
     if (PointeeType.getAddressSpace() == LangAS::opencl_generic ||
         PointeeType.getAddressSpace() == LangAS::opencl_private ||
@@ -10789,7 +10789,7 @@ Sema::ActOnFunctionDeclarator(Scope *S, Declarator &D, DeclContext *DC,
     if (getLangOpts().getOpenCLCompatibleVersion() >= 200) {
       if(const PipeType *PipeTy = PT->getAs<PipeType>()) {
         QualType ElemTy = PipeTy->getElementType();
-          if (ElemTy->isReferenceType() || ElemTy->isPointerType()) {
+          if (ElemTy->isPointerOrReferenceType()) {
             Diag(Param->getTypeSpecStartLoc(), diag::err_reference_pipe_type );
             D.setInvalidType();
           }
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 98e3df9083516..9011fa547638e 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -5992,7 +5992,7 @@ static void handleMSAllocatorAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   // Warn if the return type is not a pointer or reference type.
   if (auto *FD = dyn_cast<FunctionDecl>(D)) {
     QualType RetTy = FD->getReturnType();
-    if (!RetTy->isPointerType() && !RetTy->isReferenceType()) {
+    if (!RetTy->isPointerOrReferenceType()) {
       S.Diag(AL.getLoc(), diag::warn_declspec_allocator_nonpointer)
           << AL.getRange() << RetTy;
       return;
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index 029969c1722b7..c5003d9ac0254 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -711,7 +711,7 @@ getUuidAttrOfType(Sema &SemaRef, QualType QT,
                   llvm::SmallSetVector<const UuidAttr *, 1> &UuidAttrs) {
   // Optionally remove one level of pointer, reference or array indirection.
   const Type *Ty = QT.getTypePtr();
-  if (QT->isPointerType() || QT->isReferenceType())
+  if (QT->isPointerOrReferenceType())
     Ty = QT->getPointeeType().getTypePtr();
   else if (QT->isArrayType())
     Ty = Ty->getBaseElementTypeUnsafe();
diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp
index 87b1f98bbe5ac..c22e329bef2b9 100644
--- a/clang/lib/Sema/SemaTemplate.cpp
+++ b/clang/lib/Sema/SemaTemplate.cpp
@@ -6719,7 +6719,7 @@ ExprResult Sema::CheckTemplateArgument(NonTypeTemplateParmDecl *Param,
       auto *VD = const_cast<ValueDecl *>(Base.dyn_cast<const ValueDecl *>());
       //   For a non-type template-parameter of pointer or reference type,
       //   the value of the constant expression shall not refer to
-      assert(ParamType->isPointerType() || ParamType->isReferenceType() ||
+      assert(ParamType->isPointerOrReferenceType() ||
              ParamType->isNullPtrType());
       // -- a temporary object
       // -- a string literal
diff --git a/clang/test/AST/Interp/cxx2a.cpp b/clang/test/AST/Interp/cxx2a.cpp
index 27d1aa1a27f75..ad021b30cfd3c 100644
--- a/clang/test/AST/Interp/cxx2a.cpp
+++ b/clang/test/AST/Interp/cxx2a.cpp
@@ -13,3 +13,24 @@ consteval int aConstevalFunction() { // both-error {{consteval function never pr
   return 0;
 }
 /// We're NOT calling the above function. The diagnostics should appear anyway.
+
+namespace Covariant {
+  struct A {
+    virtual constexpr char f() const { return 'Z'; }
+    char a = f();
+  };
+
+  struct D : A {};
+  struct Covariant1 {
+    D d;
+    virtual const A *f() const;
+  };
+
+  struct Covariant3 : Covariant1 {
+    constexpr virtual const D *f() const { return &this->d; }
+  };
+
+  constexpr Covariant3 cb;
+  constexpr const Covariant1 *cb1 = &cb;
+  static_assert(cb1->f()->a == 'Z');
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/101218


More information about the cfe-commits mailing list