[clang] [ObjC] Enable diagnose_if on Objective-C methods (PR #115056)

via cfe-commits cfe-commits at lists.llvm.org
Tue Nov 5 12:23:18 PST 2024


https://github.com/apple-fcloutier created https://github.com/llvm/llvm-project/pull/115056

This change enables checking argument-dependent `diagnose_if` diagnostics on Objective-C methods.

It changes `Expr::EvaluateWithSubstitution` to accept any `NamedDecl` as the callee; in practice the callee is expected to be either a `FunctionDecl` or an `ObjCMethodDecl`.

rdar://138000724

>From 802e8e85de3b3d0a8c6ccd4f7ac3536f10183ea2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Cloutier?= <fcloutier at apple.com>
Date: Fri, 25 Oct 2024 14:41:28 -0700
Subject: [PATCH] [ObjC] Enable diagnose_if on Objective-C methods

This change enables checking argument-dependent diagnose_if diagnostics on
Objective-C methods.

It changes EvaluateWithSubstitution to accept any NamedDecl as the callee; in
practice the callee is expected to be a FunctionDecl or an ObjCMethodDecl.

rdar://138000724
---
 clang/include/clang/AST/Expr.h    |   2 +-
 clang/include/clang/Sema/Sema.h   |   6 +-
 clang/lib/AST/ExprConstant.cpp    | 166 +++++++++++++++++++++++-------
 clang/lib/Sema/SemaDeclAttr.cpp   |   6 ++
 clang/lib/Sema/SemaExprObjC.cpp   |   8 ++
 clang/lib/Sema/SemaOverload.cpp   |   4 +-
 clang/test/SemaObjC/diagnose_if.m |  10 ++
 7 files changed, 161 insertions(+), 41 deletions(-)

diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index 466c65a9685ad32..c684204ba03f113 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -741,7 +741,7 @@ class Expr : public ValueStmt {
   /// unevaluated context. Returns true if the expression could be folded to a
   /// constant.
   bool EvaluateWithSubstitution(APValue &Value, ASTContext &Ctx,
-                                const FunctionDecl *Callee,
+                                const NamedDecl *Callee,
                                 ArrayRef<const Expr*> Args,
                                 const Expr *This = nullptr) const;
 
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 93d98e1cbb9c811..62e05b7698307f4 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -10302,13 +10302,15 @@ class Sema final : public SemaBase {
                               bool MissingImplicitThis = false);
 
   /// Emit diagnostics for the diagnose_if attributes on Function, ignoring any
-  /// non-ArgDependent DiagnoseIfAttrs.
+  /// non-ArgDependent DiagnoseIfAttrs. Function should be a function, a
+  /// C++ method, or an Objective-C method. ThisArg should be non-NULL only for
+  /// C++ methods.
   ///
   /// Argument-dependent diagnose_if attributes should be checked each time a
   /// function is used as a direct callee of a function call.
   ///
   /// Returns true if any errors were emitted.
-  bool diagnoseArgDependentDiagnoseIfAttrs(const FunctionDecl *Function,
+  bool diagnoseArgDependentDiagnoseIfAttrs(const NamedDecl *Function,
                                            const Expr *ThisArg,
                                            ArrayRef<const Expr *> Args,
                                            SourceLocation Loc);
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index d664c503655ba6b..6a9a278ba226d91 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -515,10 +515,80 @@ namespace {
     Call
   };
 
+  /// Either a FunctionDecl or an ObjCMethodDecl, or empty. This class papers
+  /// over the fact that their common ancestors are DeclContext and NamedDecl,
+  /// which do not allow the enumeration of their parameters very easily.
+  class CallableDecl {
+  public:
+    using param_const_iterator = const ParmVarDecl *const *;
+
+    CallableDecl(std::nullptr_t) : DC(nullptr) {}
+    CallableDecl(const ObjCMethodDecl *MD) : DC(MD) {}
+    CallableDecl(const FunctionDecl *FD) : DC(FD) {}
+    CallableDecl() : CallableDecl(nullptr) {}
+
+    /// True when a declaration is wrapped. NOTE(review): left implicit because
+    /// CallRef::operator bool returns its CallableDecl member directly.
+    operator bool() const { return DC != nullptr; }
+
+    /// The wrapped declaration as a NamedDecl, or null when empty.
+    const NamedDecl *getAsNamedDecl() const {
+      if (const auto *Func = dyn_cast_or_null<FunctionDecl>(DC))
+        return Func;
+      return cast_or_null<ObjCMethodDecl>(DC);
+    }
+
+    /// Views of the wrapped declaration. The dyn_cast-based ones return null
+    /// when the wrapper is empty or holds a different kind of declaration.
+    const DeclContext *getAsDeclContext() const { return DC; }
+    const FunctionDecl *getAsFunctionDecl() const {
+      return dyn_cast_or_null<FunctionDecl>(DC);
+    }
+    const CXXMethodDecl *getAsCXXMethodDecl() const {
+      return dyn_cast_or_null<CXXMethodDecl>(DC);
+    }
+    const ObjCMethodDecl *getAsObjCMethodDecl() const {
+      return dyn_cast_or_null<ObjCMethodDecl>(DC);
+    }
+
+    /// Parameter accessors, dispatched on the wrapped kind. They require a
+    /// non-empty wrapper: cast<> asserts on a null declaration.
+    unsigned getNumParams() const {
+      if (const auto *Func = getAsFunctionDecl())
+        return Func->getNumParams();
+      return cast<ObjCMethodDecl>(DC)->param_size();
+    }
+
+    const ParmVarDecl *getParamDecl(unsigned I) const {
+      if (const auto *Func = getAsFunctionDecl())
+        return Func->getParamDecl(I);
+      return cast<ObjCMethodDecl>(DC)->getParamDecl(I);
+    }
+
+    param_const_iterator param_begin() const {
+      if (const auto *Func = getAsFunctionDecl())
+        return Func->param_begin();
+      return cast<ObjCMethodDecl>(DC)->param_begin();
+    }
+
+    param_const_iterator param_end() const {
+      if (const auto *Func = getAsFunctionDecl())
+        return Func->param_end();
+      return cast<ObjCMethodDecl>(DC)->param_end();
+    }
+
+  private:
+    const DeclContext *DC;
+  };
+
+  inline bool operator!=(CallableDecl LHS, CallableDecl RHS) {
+    return LHS.getAsDeclContext() != RHS.getAsDeclContext(); // by identity
+  }
+
   /// A reference to a particular call and its arguments.
   struct CallRef {
     CallRef() : OrigCallee(), CallIndex(0), Version() {}
-    CallRef(const FunctionDecl *Callee, unsigned CallIndex, unsigned Version)
+    CallRef(CallableDecl Callee, unsigned CallIndex, unsigned Version)
         : OrigCallee(Callee), CallIndex(CallIndex), Version(Version) {}
 
     explicit operator bool() const { return OrigCallee; }
@@ -526,15 +596,17 @@ namespace {
     /// Get the parameter that the caller initialized, corresponding to the
     /// given parameter in the callee.
     const ParmVarDecl *getOrigParam(const ParmVarDecl *PVD) const {
-      return OrigCallee ? OrigCallee->getParamDecl(PVD->getFunctionScopeIndex())
-                        : PVD;
+      // Nothing to remap without both an original callee and a parameter.
+      if (!OrigCallee || !PVD)
+        return PVD;
+      return OrigCallee.getParamDecl(PVD->getFunctionScopeIndex());
     }
 
     /// The callee at the point where the arguments were evaluated. This might
     /// be different from the actual callee (a different redeclaration, or a
     /// virtual override), but this function's parameters are the ones that
     /// appear in the parameter map.
-    const FunctionDecl *OrigCallee;
+    CallableDecl OrigCallee;
     /// The call index of the frame that holds the argument values.
     unsigned CallIndex;
     /// The version of the parameters corresponding to this call.
@@ -549,8 +621,9 @@ namespace {
     /// Parent - The caller of this stack frame.
     CallStackFrame *Caller;
 
-    /// Callee - The function which was called.
-    const FunctionDecl *Callee;
+    /// Callee - The function which was called. Wraps a Function or an
+    /// ObjCMethod.
+    CallableDecl Callee;
 
     /// This - The binding for the this pointer in this call, if any.
     const LValue *This;
@@ -598,6 +671,10 @@ namespace {
       return {Callee, Index, ++CurTempVersion};
     }
 
+    CallRef createCall(const ObjCMethodDecl *Callee) {
+      return CallRef(Callee, Index, ++CurTempVersion); // new version per call
+    }
+
     // FIXME: Adding this to every 'CallStackFrame' may have a nontrivial impact
     // on the overall stack usage of deeply-recursing constexpr evaluations.
     // (We should cache this map rather than recomputing it repeatedly.)
@@ -610,8 +687,15 @@ namespace {
     FieldDecl *LambdaThisCaptureField = nullptr;
 
     CallStackFrame(EvalInfo &Info, SourceRange CallRange,
-                   const FunctionDecl *Callee, const LValue *This,
+                   CallableDecl Callable, const LValue *This,
                    const Expr *CallExpr, CallRef Arguments);
+
+    /// Overload for a literal nullptr callee: delegates with an empty wrapper.
+    CallStackFrame(EvalInfo &Info, SourceRange CallRange, std::nullptr_t,
+                   const LValue *This, const Expr *CallExpr, CallRef Arguments)
+        : CallStackFrame(Info, CallRange, CallableDecl(), This, CallExpr,
+                         Arguments) {}
+
     ~CallStackFrame();
 
     // Return the temporary for Key whose version number is Version.
@@ -654,10 +738,10 @@ namespace {
 
     Frame *getCaller() const override { return Caller; }
     SourceRange getCallRange() const override { return CallRange; }
-    const FunctionDecl *getCallee() const override { return Callee; }
+    const FunctionDecl *getCallee() const override { return Callee.getAsFunctionDecl(); }
 
     bool isStdFunction() const {
-      for (const DeclContext *DC = Callee; DC; DC = DC->getParent())
+      for (const DeclContext *DC = Callee.getAsDeclContext(); DC; DC = DC->getParent())
         if (DC->isStdNamespace())
           return true;
       return false;
@@ -1139,7 +1223,7 @@ namespace {
     StdAllocatorCaller getStdAllocatorCaller(StringRef FnName) const {
       for (const CallStackFrame *Call = CurrentCall; Call != &BottomFrame;
            Call = Call->Caller) {
-        const auto *MD = dyn_cast_or_null<CXXMethodDecl>(Call->Callee);
+        const auto *MD = Call->Callee.getAsCXXMethodDecl();
         if (!MD)
           continue;
         const IdentifierInfo *FnII = MD->getIdentifier();
@@ -1509,7 +1593,7 @@ void SubobjectDesignator::diagnosePointerArithmetic(EvalInfo &Info,
 }
 
 CallStackFrame::CallStackFrame(EvalInfo &Info, SourceRange CallRange,
-                               const FunctionDecl *Callee, const LValue *This,
+                               CallableDecl Callee, const LValue *This,
                                const Expr *CallExpr, CallRef Call)
     : Info(Info), Caller(Info.CurrentCall), Callee(Callee), This(This),
       CallExpr(CallExpr), Arguments(Call), CallRange(CallRange),
@@ -1995,12 +2079,13 @@ APValue *EvalInfo::createHeapAlloc(const Expr *E, QualType T, LValue &LV) {
 /// Produce a string describing the given constexpr call.
 void CallStackFrame::describe(raw_ostream &Out) const {
   unsigned ArgIndex = 0;
-  bool IsMemberCall =
-      isa<CXXMethodDecl>(Callee) && !isa<CXXConstructorDecl>(Callee) &&
-      cast<CXXMethodDecl>(Callee)->isImplicitObjectMemberFunction();
+  bool IsMemberCall = false;
+  const NamedDecl *ND = Callee.getAsNamedDecl();
+  if (auto MD = Callee.getAsCXXMethodDecl())
+    IsMemberCall = !isa<CXXConstructorDecl>(MD) && MD->isImplicitObjectMemberFunction();
 
   if (!IsMemberCall)
-    Callee->getNameForDiagnostic(Out, Info.Ctx.getPrintingPolicy(),
+    ND->getNameForDiagnostic(Out, Info.Ctx.getPrintingPolicy(),
                                  /*Qualified=*/false);
 
   if (This && IsMemberCall) {
@@ -2026,15 +2111,15 @@ void CallStackFrame::describe(raw_ostream &Out) const {
           Info.Ctx.getLValueReferenceType(This->Designator.MostDerivedType));
       Out << ".";
     }
-    Callee->getNameForDiagnostic(Out, Info.Ctx.getPrintingPolicy(),
+    ND->getNameForDiagnostic(Out, Info.Ctx.getPrintingPolicy(),
                                  /*Qualified=*/false);
     IsMemberCall = false;
   }
 
   Out << '(';
 
-  for (FunctionDecl::param_const_iterator I = Callee->param_begin(),
-       E = Callee->param_end(); I != E; ++I, ++ArgIndex) {
+  for (CallableDecl::param_const_iterator I = Callee.param_begin(),
+       E = Callee.param_end(); I != E; ++I, ++ArgIndex) {
     if (ArgIndex > (unsigned)IsMemberCall)
       Out << ", ";
 
@@ -2046,7 +2131,7 @@ void CallStackFrame::describe(raw_ostream &Out) const {
       Out << "<...>";
 
     if (ArgIndex == 0 && IsMemberCall)
-      Out << "->" << *Callee << '(';
+      Out << "->" << *ND << '(';
   }
 
   Out << ')';
@@ -2280,8 +2365,8 @@ static void NoteLValueLocation(EvalInfo &Info, APValue::LValueBase Base) {
     for (CallStackFrame *F = Info.CurrentCall; F; F = F->Caller) {
       if (F->Arguments.CallIndex == Base.getCallIndex() &&
           F->Arguments.Version == Base.getVersion() && F->Callee &&
-          Idx < F->Callee->getNumParams()) {
-        VD = F->Callee->getParamDecl(Idx);
+          Idx < F->Callee.getNumParams()) {
+        VD = F->Callee.getParamDecl(Idx);
         break;
       }
     }
@@ -3460,8 +3545,9 @@ static bool evaluateVarDeclInit(EvalInfo &Info, const Expr *E,
       // not declared within the call operator are captures and during checking
       // of a potential constant expression, assume they are unknown constant
       // expressions.
-      assert(isLambdaCallOperator(Frame->Callee) &&
-             (VD->getDeclContext() != Frame->Callee || VD->isInitCapture()) &&
+      const auto *FD = Frame->Callee.getAsFunctionDecl();
+      assert(isLambdaCallOperator(FD) &&
+             (VD->getDeclContext() != FD || VD->isInitCapture()) &&
              "missing value for local variable");
       if (Info.checkingPotentialConstantExpression())
         return false;
@@ -3486,7 +3572,7 @@ static bool evaluateVarDeclInit(EvalInfo &Info, const Expr *E,
     // constant expressions.
     if (!Info.checkingPotentialConstantExpression() ||
         !Info.CurrentCall->Callee ||
-        !Info.CurrentCall->Callee->Equals(VD->getDeclContext())) {
+        !Info.CurrentCall->Callee.getAsDeclContext()->Equals(VD->getDeclContext())) {
       if (Info.getLangOpts().CPlusPlus11) {
         Info.FFDiag(E, diag::note_constexpr_function_param_value_unknown)
             << VD;
@@ -8832,7 +8918,7 @@ bool LValueExprEvaluator::VisitVarDecl(const Expr *E, const VarDecl *VD) {
   // to within 'E' actually represents a lambda-capture that maps to a
   // data-member/field within the closure object, and if so, evaluate to the
   // field or what the field refers to.
-  if (Info.CurrentCall && isLambdaCallOperator(Info.CurrentCall->Callee) &&
+  if (Info.CurrentCall && isLambdaCallOperator(Info.CurrentCall->Callee.getAsFunctionDecl()) &&
       isa<DeclRefExpr>(E) &&
       cast<DeclRefExpr>(E)->refersToEnclosingVariableOrCapture()) {
     // We don't always have a complete capture-map when checking or inferring if
@@ -8843,7 +8929,7 @@ bool LValueExprEvaluator::VisitVarDecl(const Expr *E, const VarDecl *VD) {
       return false;
 
     if (auto *FD = Info.CurrentCall->LambdaCaptureFields.lookup(VD)) {
-      const auto *MD = cast<CXXMethodDecl>(Info.CurrentCall->Callee);
+      const auto *MD = Info.CurrentCall->Callee.getAsCXXMethodDecl();
       return HandleLambdaCapture(Info, E, Result, MD, FD,
                                  FD->getType()->isReferenceType());
     }
@@ -8859,7 +8945,7 @@ bool LValueExprEvaluator::VisitVarDecl(const Expr *E, const VarDecl *VD) {
     // variable) or be ill-formed (and trigger an appropriate evaluation
     // diagnostic)).
     CallStackFrame *CurrFrame = Info.CurrentCall;
-    if (CurrFrame->Callee && CurrFrame->Callee->Equals(VD->getDeclContext())) {
+    if (CurrFrame->Callee && CurrFrame->Callee.getAsDeclContext()->Equals(VD->getDeclContext())) {
       // Function parameters are stored in some caller's frame. (Usually the
       // immediate caller, but for an inherited constructor they may be more
       // distant.)
@@ -9383,7 +9469,7 @@ class PointerExprEvaluator
       return false;
 
     bool IsExplicitLambda =
-        isLambdaCallWithExplicitObjectParameter(Info.CurrentCall->Callee);
+        isLambdaCallWithExplicitObjectParameter(Info.CurrentCall->Callee.getAsFunctionDecl());
     if (!IsExplicitLambda) {
       if (!Info.CurrentCall->This) {
         DiagnoseInvalidUseOfThis();
@@ -9393,7 +9479,7 @@ class PointerExprEvaluator
       Result = *Info.CurrentCall->This;
     }
 
-    if (isLambdaCallOperator(Info.CurrentCall->Callee)) {
+    if (isLambdaCallOperator(Info.CurrentCall->Callee.getAsFunctionDecl())) {
       // Ensure we actually have captured 'this'. If something was wrong with
       // 'this' capture, the error would have been previously reported.
       // Otherwise we can be inside of a default initialization of an object
@@ -9407,7 +9493,7 @@ class PointerExprEvaluator
         return true;
       }
 
-      const auto *MD = cast<CXXMethodDecl>(Info.CurrentCall->Callee);
+      const auto *MD = Info.CurrentCall->Callee.getAsCXXMethodDecl();
       return HandleLambdaCapture(
           Info, E, Result, MD, Info.CurrentCall->LambdaThisCaptureField,
           Info.CurrentCall->LambdaThisCaptureField->getType()->isPointerType());
@@ -9539,7 +9625,7 @@ bool PointerExprEvaluator::VisitCastExpr(const CastExpr *E) {
       //    that back to `const __impl*` in its body.
       if (VoidPtrCastMaybeOK &&
           (Info.getStdAllocatorCaller("allocate") ||
-           IsDeclSourceLocationCurrent(Info.CurrentCall->Callee) ||
+           IsDeclSourceLocationCurrent(Info.CurrentCall->Callee.getAsFunctionDecl()) ||
            Info.getLangOpts().CPlusPlus26)) {
         // Permitted.
       } else {
@@ -17403,12 +17489,14 @@ bool Expr::isCXX11ConstantExpr(const ASTContext &Ctx, APValue *Result,
 }
 
 bool Expr::EvaluateWithSubstitution(APValue &Value, ASTContext &Ctx,
-                                    const FunctionDecl *Callee,
+                                    const NamedDecl *Callee,
                                     ArrayRef<const Expr*> Args,
                                     const Expr *This) const {
   assert(!isValueDependent() &&
          "Expression evaluator can't be called on a dependent expression.");
 
+  assert(Callee != nullptr || Args.empty() && "substitutions always fail when Callee is nullptr");
+
   llvm::TimeTraceScope TimeScope("EvaluateWithSubstitution", [&] {
     std::string Name;
     llvm::raw_string_ostream OS(Name);
@@ -17440,13 +17528,19 @@ bool Expr::EvaluateWithSubstitution(APValue &Value, ASTContext &Ctx,
     Info.EvalStatus.HasSideEffects = false;
   }
 
-  CallRef Call = Info.CurrentCall->createCall(Callee);
+  CallRef Call;
+  if (Callee != nullptr) {
+    if (auto Func = dyn_cast<FunctionDecl>(Callee))
+      Call = Info.CurrentCall->createCall(Func);
+    else if (auto Method = dyn_cast<ObjCMethodDecl>(Callee))
+      Call = Info.CurrentCall->createCall(Method);
+  }
   for (ArrayRef<const Expr*>::iterator I = Args.begin(), E = Args.end();
        I != E; ++I) {
     unsigned Idx = I - Args.begin();
-    if (Idx >= Callee->getNumParams())
+    if (Idx >= Call.OrigCallee.getNumParams())
       break;
-    const ParmVarDecl *PVD = Callee->getParamDecl(Idx);
+    const ParmVarDecl *PVD = Call.OrigCallee.getParamDecl(Idx);
     if ((*I)->isValueDependent() ||
         !EvaluateCallArg(PVD, *I, Call, Info) ||
         Info.EvalStatus.HasSideEffects) {
@@ -17466,7 +17560,7 @@ bool Expr::EvaluateWithSubstitution(APValue &Value, ASTContext &Ctx,
   Info.EvalStatus.HasSideEffects = false;
 
   // Build fake call to Callee.
-  CallStackFrame Frame(Info, Callee->getLocation(), Callee, ThisPtr, This,
+  CallStackFrame Frame(Info, Callee->getLocation(), Call.OrigCallee, ThisPtr, This,
                        Call);
   // FIXME: Missing ExprWithCleanups in enable_if conditions?
   FullExpressionRAII Scope(Info);
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index d05d326178e1b85..c820c3c2bb871b9 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -742,6 +742,10 @@ class ArgumentDependenceChecker
     Parms.insert(FD->param_begin(), FD->param_end());
   }
 
+  /// ObjC overload. NOTE(review): confirm no other member needs initializing.
+  ArgumentDependenceChecker(const ObjCMethodDecl *MD) {
+    Parms.insert(MD->param_begin(), MD->param_end());
+  }
   bool referencesArgs(Expr *E) {
     Result = false;
     TraverseStmt(E);
@@ -866,6 +870,8 @@ static void handleDiagnoseIfAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   bool ArgDependent = false;
   if (const auto *FD = dyn_cast<FunctionDecl>(D))
     ArgDependent = ArgumentDependenceChecker(FD).referencesArgs(Cond);
+  else if (const auto *MD = dyn_cast<ObjCMethodDecl>(D))
+    ArgDependent = ArgumentDependenceChecker(MD).referencesArgs(Cond);
   D->addAttr(::new (S.Context) DiagnoseIfAttr(
       S.Context, AL, Cond, Msg, DiagType, ArgDependent, cast<NamedDecl>(D)));
 }
diff --git a/clang/lib/Sema/SemaExprObjC.cpp b/clang/lib/Sema/SemaExprObjC.cpp
index 3fcbbb417ff1faa..a56b9026e236e49 100644
--- a/clang/lib/Sema/SemaExprObjC.cpp
+++ b/clang/lib/Sema/SemaExprObjC.cpp
@@ -2706,6 +2706,10 @@ ExprResult SemaObjC::BuildClassMessage(
         << Method->getDeclName();
   }
 
+  // Check any arg-dependent diagnose_if conditions.
+  if (Method)
+    SemaRef.diagnoseArgDependentDiagnoseIfAttrs(Method, nullptr, ArgsIn, RBracLoc);
+
   // Warn about explicit call of +initialize on its own class. But not on 'super'.
   if (Method && Method->getMethodFamily() == OMF_initialize) {
     if (!SuperLoc.isValid()) {
@@ -3239,6 +3243,10 @@ ExprResult SemaObjC::BuildInstanceMessage(
           diag::err_illegal_message_expr_incomplete_type))
     return ExprError();
 
+  // Check any arg-dependent diagnose_if conditions.
+  if (Method)
+    SemaRef.diagnoseArgDependentDiagnoseIfAttrs(Method, nullptr, ArgsIn, RBracLoc);
+
   // In ARC, forbid the user from sending messages to
   // retain/release/autorelease/dealloc/retainCount explicitly.
   if (getLangOpts().ObjCAutoRefCount) {
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index 4aeceba128b29b7..43c307ac19cdfe7 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -7360,7 +7360,7 @@ static bool diagnoseDiagnoseIfAttrsWith(Sema &S, const NamedDecl *ND,
   return false;
 }
 
-bool Sema::diagnoseArgDependentDiagnoseIfAttrs(const FunctionDecl *Function,
+bool Sema::diagnoseArgDependentDiagnoseIfAttrs(const NamedDecl *Function,
                                                const Expr *ThisArg,
                                                ArrayRef<const Expr *> Args,
                                                SourceLocation Loc) {
@@ -7372,7 +7372,7 @@ bool Sema::diagnoseArgDependentDiagnoseIfAttrs(const FunctionDecl *Function,
         // EvaluateWithSubstitution only cares about the position of each
         // argument in the arg list, not the ParmVarDecl* it maps to.
         if (!DIA->getCond()->EvaluateWithSubstitution(
-                Result, Context, cast<FunctionDecl>(DIA->getParent()), Args, ThisArg))
+                Result, Context, DIA->getParent(), Args, ThisArg))
           return false;
         return Result.isInt() && Result.getInt().getBoolValue();
       });
diff --git a/clang/test/SemaObjC/diagnose_if.m b/clang/test/SemaObjC/diagnose_if.m
index 9f281e4252dfa5c..ab97ae1225306b4 100644
--- a/clang/test/SemaObjC/diagnose_if.m
+++ b/clang/test/SemaObjC/diagnose_if.m
@@ -6,11 +6,21 @@
 
 @interface I
 -(void)meth _diagnose_if(1, "don't use this", "warning"); // expected-note 1{{from 'diagnose_if'}}
+-(void)meth:(int *)x _diagnose_if(x == 0, "x is NULL", "warning"); // expected-note 1{{from 'diagnose_if'}}
++(void)meth:(int *)x _diagnose_if(x == 0, "x is NULL", "warning"); // expected-note 1{{from 'diagnose_if'}}
 @property (assign) id prop _diagnose_if(1, "don't use this", "warning"); // expected-note 2{{from 'diagnose_if'}}
 @end
 
 void test(I *i) {
   [i meth]; // expected-warning {{don't use this}}
+
+  int x;
+  [i meth:&x];
+  [i meth:0]; //expected-warning {{x is NULL}}
+
+  [I meth:&x];
+  [I meth:0]; //expected-warning {{x is NULL}}
+
   id o1 = i.prop; // expected-warning {{don't use this}}
   id o2 = [i prop]; // expected-warning {{don't use this}}
 }



More information about the cfe-commits mailing list