[clang] [llvm] Reland "[analyzer][NFC] Reorganize Z3 report refutation" (PR #97265)

Balazs Benics via cfe-commits cfe-commits at lists.llvm.org
Mon Jul 1 01:53:47 PDT 2024


https://github.com/steakhal created https://github.com/llvm/llvm-project/pull/97265

This is exactly as originally landed in #95128,
but now the minimum required Z3 version has been raised in #96682 to support this change.
https://discourse.llvm.org/t/bump-minimal-z3-requirements-from-4-7-1-to-4-8-9/79664/4

---

This change keeps existing behavior, namely that if we hit a Z3 timeout we will accept the report as "satisfiable".

This prepares for the commit "Harden safeguards for Z3 query times". https://discourse.llvm.org/t/analyzer-rfc-taming-z3-query-times/79520

(cherry picked from commit 89c26f6c7b0a6dfa257ec090fcf5b6e6e0c89aab)

>From 3d5acf85a3e67f76b127e04b47b4dd2335c5acf8 Mon Sep 17 00:00:00 2001
From: Balazs Benics <benicsbalazs at gmail.com>
Date: Mon, 1 Jul 2024 10:51:11 +0200
Subject: [PATCH] Reland "[analyzer][NFC] Reorganize Z3 report refutation"

This is exactly as originally landed in #95128,
but now the minimum required Z3 version has been raised in #96682 to support this change.
https://discourse.llvm.org/t/bump-minimal-z3-requirements-from-4-7-1-to-4-8-9/79664/4

---

This change keeps existing behavior, namely that if we hit a Z3 timeout
we will accept the report as "satisfiable".

This prepares for the commit "Harden safeguards for Z3 query times".
https://discourse.llvm.org/t/analyzer-rfc-taming-z3-query-times/79520

(cherry picked from commit 89c26f6c7b0a6dfa257ec090fcf5b6e6e0c89aab)
---
 .../Core/BugReporter/BugReporterVisitors.h    |  23 ----
 .../Core/BugReporter/Z3CrosscheckVisitor.h    |  66 ++++++++++
 .../Core/PathSensitive/SMTConstraintManager.h |   5 +-
 clang/lib/StaticAnalyzer/Core/BugReporter.cpp |  28 ++++-
 .../Core/BugReporterVisitors.cpp              |  76 -----------
 clang/lib/StaticAnalyzer/Core/CMakeLists.txt  |   1 +
 .../Core/Z3CrosscheckVisitor.cpp              | 118 ++++++++++++++++++
 .../test/Analysis/z3/crosscheck-statistics.c  |  33 +++++
 clang/unittests/StaticAnalyzer/CMakeLists.txt |   1 +
 .../StaticAnalyzer/Z3CrosscheckOracleTest.cpp |  59 +++++++++
 llvm/include/llvm/Support/SMTAPI.h            |  19 +++
 llvm/lib/Support/Z3Solver.cpp                 | 116 ++++++++++++++---
 12 files changed, 419 insertions(+), 126 deletions(-)
 create mode 100644 clang/include/clang/StaticAnalyzer/Core/BugReporter/Z3CrosscheckVisitor.h
 create mode 100644 clang/lib/StaticAnalyzer/Core/Z3CrosscheckVisitor.cpp
 create mode 100644 clang/test/Analysis/z3/crosscheck-statistics.c
 create mode 100644 clang/unittests/StaticAnalyzer/Z3CrosscheckOracleTest.cpp

diff --git a/clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h b/clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h
index cc3d93aabafda..f97514955a591 100644
--- a/clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h
+++ b/clang/include/clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h
@@ -597,29 +597,6 @@ class SuppressInlineDefensiveChecksVisitor final : public BugReporterVisitor {
                                    PathSensitiveBugReport &BR) override;
 };
 
-/// The bug visitor will walk all the nodes in a path and collect all the
-/// constraints. When it reaches the root node, will create a refutation
-/// manager and check if the constraints are satisfiable
-class FalsePositiveRefutationBRVisitor final : public BugReporterVisitor {
-private:
-  /// Holds the constraints in a given path
-  ConstraintMap Constraints;
-
-public:
-  FalsePositiveRefutationBRVisitor();
-
-  void Profile(llvm::FoldingSetNodeID &ID) const override;
-
-  PathDiagnosticPieceRef VisitNode(const ExplodedNode *N,
-                                   BugReporterContext &BRC,
-                                   PathSensitiveBugReport &BR) override;
-
-  void finalizeVisitor(BugReporterContext &BRC, const ExplodedNode *EndPathNode,
-                       PathSensitiveBugReport &BR) override;
-  void addConstraints(const ExplodedNode *N,
-                      bool OverwriteConstraintsOnExistingSyms);
-};
-
 /// The visitor detects NoteTags and displays the event notes they contain.
 class TagVisitor : public BugReporterVisitor {
 public:
diff --git a/clang/include/clang/StaticAnalyzer/Core/BugReporter/Z3CrosscheckVisitor.h b/clang/include/clang/StaticAnalyzer/Core/BugReporter/Z3CrosscheckVisitor.h
new file mode 100644
index 0000000000000..9413fd739f607
--- /dev/null
+++ b/clang/include/clang/StaticAnalyzer/Core/BugReporter/Z3CrosscheckVisitor.h
@@ -0,0 +1,66 @@
+//===- Z3CrosscheckVisitor.h - Crosscheck reports with Z3 -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//  This file defines the visitor and utilities around it for Z3 report
+//  refutation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_STATICANALYZER_CORE_BUGREPORTER_Z3CROSSCHECKVISITOR_H
+#define LLVM_CLANG_STATICANALYZER_CORE_BUGREPORTER_Z3CROSSCHECKVISITOR_H
+
+#include "clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h"
+
+namespace clang::ento {
+
+/// Walks every node of a bug path and collects the range constraints of the
+/// visited states. Once finalized, it asks a fresh Z3 solver whether these
+/// constraints are satisfiable and stores the answer in the given Z3Result.
+class Z3CrosscheckVisitor final : public BugReporterVisitor {
+public:
+  struct Z3Result {
+    std::optional<bool> IsSAT = std::nullopt; // nullopt: Z3 gave no verdict.
+  };
+  explicit Z3CrosscheckVisitor(Z3CrosscheckVisitor::Z3Result &Result);
+
+  void Profile(llvm::FoldingSetNodeID &ID) const override;
+
+  PathDiagnosticPieceRef VisitNode(const ExplodedNode *N,
+                                   BugReporterContext &BRC,
+                                   PathSensitiveBugReport &BR) override;
+
+  void finalizeVisitor(BugReporterContext &BRC, const ExplodedNode *EndPathNode,
+                       PathSensitiveBugReport &BR) override;
+
+private:
+  void addConstraints(const ExplodedNode *N,
+                      bool OverwriteConstraintsOnExistingSyms);
+
+  /// Holds the constraints in a given path.
+  ConstraintMap Constraints;
+  Z3Result &Result; // Owned by the caller; written in finalizeVisitor().
+};
+
+/// The oracle will decide if a report should be accepted or rejected based on
+/// the results of the Z3 solver.
+class Z3CrosscheckOracle {
+public:
+  enum Z3Decision {
+    AcceptReport, // The report was SAT, or the query timed out (UNDEF).
+    RejectReport, // The report was UNSAT.
+  };
+
+  /// Makes a decision for accepting or rejecting the report based on the
+  /// result of the corresponding Z3 query.
+  static Z3Decision
+  interpretQueryResult(const Z3CrosscheckVisitor::Z3Result &Query);
+};
+
+} // namespace clang::ento
+
+#endif // LLVM_CLANG_STATICANALYZER_CORE_BUGREPORTER_Z3CROSSCHECKVISITOR_H
diff --git a/clang/include/clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h b/clang/include/clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h
index 5116a4c06850d..bf18c353b8508 100644
--- a/clang/include/clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h
+++ b/clang/include/clang/StaticAnalyzer/Core/PathSensitive/SMTConstraintManager.h
@@ -34,7 +34,10 @@ class SMTConstraintManager : public clang::ento::SimpleConstraintManager {
 public:
   SMTConstraintManager(clang::ento::ExprEngine *EE,
                        clang::ento::SValBuilder &SB)
-      : SimpleConstraintManager(EE, SB) {}
+      : SimpleConstraintManager(EE, SB) {
+    Solver->setBoolParam("model", true); // Enable model finding
+    Solver->setUnsignedParam("timeout", 15000 /*milliseconds*/);
+  }
   virtual ~SMTConstraintManager() = default;
 
   //===------------------------------------------------------------------===//
diff --git a/clang/lib/StaticAnalyzer/Core/BugReporter.cpp b/clang/lib/StaticAnalyzer/Core/BugReporter.cpp
index 14ca507a16d55..c9a7fd0e035c2 100644
--- a/clang/lib/StaticAnalyzer/Core/BugReporter.cpp
+++ b/clang/lib/StaticAnalyzer/Core/BugReporter.cpp
@@ -35,6 +35,7 @@
 #include "clang/StaticAnalyzer/Core/AnalyzerOptions.h"
 #include "clang/StaticAnalyzer/Core/BugReporter/BugReporterVisitors.h"
 #include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
+#include "clang/StaticAnalyzer/Core/BugReporter/Z3CrosscheckVisitor.h"
 #include "clang/StaticAnalyzer/Core/Checker.h"
 #include "clang/StaticAnalyzer/Core/CheckerManager.h"
 #include "clang/StaticAnalyzer/Core/CheckerRegistryData.h"
@@ -86,6 +87,11 @@ STATISTIC(MaxValidBugClassSize,
           "The maximum number of bug reports in the same equivalence class "
           "where at least one report is valid (not suppressed)");
 
+STATISTIC(NumTimesReportPassesZ3, "Number of reports passed Z3");
+STATISTIC(NumTimesReportRefuted, "Number of reports refuted by Z3");
+STATISTIC(NumTimesReportEQClassWasExhausted,
+          "Number of times all reports of an equivalence class was refuted");
+
 BugReporterVisitor::~BugReporterVisitor() = default;
 
 void BugReporterContext::anchor() {}
@@ -2864,21 +2870,31 @@ std::optional<PathDiagnosticBuilder> PathDiagnosticBuilder::findValidReport(
         // If crosscheck is enabled, remove all visitors, add the refutation
         // visitor and check again
         R->clearVisitors();
-        R->addVisitor<FalsePositiveRefutationBRVisitor>();
+        Z3CrosscheckVisitor::Z3Result CrosscheckResult; // Set by the visitor.
+        R->addVisitor<Z3CrosscheckVisitor>(CrosscheckResult);
 
         // We don't overwrite the notes inserted by other visitors because the
         // refutation manager does not add any new note to the path
         generateVisitorsDiagnostics(R, BugPath->ErrorNode, BRC);
+        switch (Z3CrosscheckOracle::interpretQueryResult(CrosscheckResult)) {
+        case Z3CrosscheckOracle::RejectReport:
+          ++NumTimesReportRefuted;
+          R->markInvalid("Infeasible constraints", /*Data=*/nullptr);
+          continue;
+        case Z3CrosscheckOracle::AcceptReport:
+          ++NumTimesReportPassesZ3;
+          break;
+        }
       }
 
-      // Check if the bug is still valid
-      if (R->isValid())
-        return PathDiagnosticBuilder(
-            std::move(BRC), std::move(BugPath->BugPath), BugPath->Report,
-            BugPath->ErrorNode, std::move(visitorNotes));
+      assert(R->isValid()); // Z3-refuted reports took the 'continue' above.
+      return PathDiagnosticBuilder(std::move(BRC), std::move(BugPath->BugPath),
+                                   BugPath->Report, BugPath->ErrorNode,
+                                   std::move(visitorNotes));
     }
   }
 
+  ++NumTimesReportEQClassWasExhausted; // No report in the class was valid.
   return {};
 }
 
diff --git a/clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp b/clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp
index 4ff4f7de425ca..7102bf51a57e8 100644
--- a/clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp
+++ b/clang/lib/StaticAnalyzer/Core/BugReporterVisitors.cpp
@@ -3447,82 +3447,6 @@ UndefOrNullArgVisitor::VisitNode(const ExplodedNode *N, BugReporterContext &BRC,
   return nullptr;
 }
 
-//===----------------------------------------------------------------------===//
-// Implementation of FalsePositiveRefutationBRVisitor.
-//===----------------------------------------------------------------------===//
-
-FalsePositiveRefutationBRVisitor::FalsePositiveRefutationBRVisitor()
-    : Constraints(ConstraintMap::Factory().getEmptyMap()) {}
-
-void FalsePositiveRefutationBRVisitor::finalizeVisitor(
-    BugReporterContext &BRC, const ExplodedNode *EndPathNode,
-    PathSensitiveBugReport &BR) {
-  // Collect new constraints
-  addConstraints(EndPathNode, /*OverwriteConstraintsOnExistingSyms=*/true);
-
-  // Create a refutation manager
-  llvm::SMTSolverRef RefutationSolver = llvm::CreateZ3Solver();
-  ASTContext &Ctx = BRC.getASTContext();
-
-  // Add constraints to the solver
-  for (const auto &I : Constraints) {
-    const SymbolRef Sym = I.first;
-    auto RangeIt = I.second.begin();
-
-    llvm::SMTExprRef SMTConstraints = SMTConv::getRangeExpr(
-        RefutationSolver, Ctx, Sym, RangeIt->From(), RangeIt->To(),
-        /*InRange=*/true);
-    while ((++RangeIt) != I.second.end()) {
-      SMTConstraints = RefutationSolver->mkOr(
-          SMTConstraints, SMTConv::getRangeExpr(RefutationSolver, Ctx, Sym,
-                                                RangeIt->From(), RangeIt->To(),
-                                                /*InRange=*/true));
-    }
-
-    RefutationSolver->addConstraint(SMTConstraints);
-  }
-
-  // And check for satisfiability
-  std::optional<bool> IsSAT = RefutationSolver->check();
-  if (!IsSAT)
-    return;
-
-  if (!*IsSAT)
-    BR.markInvalid("Infeasible constraints", EndPathNode->getLocationContext());
-}
-
-void FalsePositiveRefutationBRVisitor::addConstraints(
-    const ExplodedNode *N, bool OverwriteConstraintsOnExistingSyms) {
-  // Collect new constraints
-  ConstraintMap NewCs = getConstraintMap(N->getState());
-  ConstraintMap::Factory &CF = N->getState()->get_context<ConstraintMap>();
-
-  // Add constraints if we don't have them yet
-  for (auto const &C : NewCs) {
-    const SymbolRef &Sym = C.first;
-    if (!Constraints.contains(Sym)) {
-      // This symbol is new, just add the constraint.
-      Constraints = CF.add(Constraints, Sym, C.second);
-    } else if (OverwriteConstraintsOnExistingSyms) {
-      // Overwrite the associated constraint of the Symbol.
-      Constraints = CF.remove(Constraints, Sym);
-      Constraints = CF.add(Constraints, Sym, C.second);
-    }
-  }
-}
-
-PathDiagnosticPieceRef FalsePositiveRefutationBRVisitor::VisitNode(
-    const ExplodedNode *N, BugReporterContext &, PathSensitiveBugReport &) {
-  addConstraints(N, /*OverwriteConstraintsOnExistingSyms=*/false);
-  return nullptr;
-}
-
-void FalsePositiveRefutationBRVisitor::Profile(
-    llvm::FoldingSetNodeID &ID) const {
-  static int Tag = 0;
-  ID.AddPointer(&Tag);
-}
-
 //===----------------------------------------------------------------------===//
 // Implementation of TagVisitor.
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/StaticAnalyzer/Core/CMakeLists.txt b/clang/lib/StaticAnalyzer/Core/CMakeLists.txt
index 8672876c0608d..fb9394a519eb7 100644
--- a/clang/lib/StaticAnalyzer/Core/CMakeLists.txt
+++ b/clang/lib/StaticAnalyzer/Core/CMakeLists.txt
@@ -51,6 +51,7 @@ add_clang_library(clangStaticAnalyzerCore
   SymbolManager.cpp
   TextDiagnostics.cpp
   WorkList.cpp
+  Z3CrosscheckVisitor.cpp
 
   LINK_LIBS
   clangAST
diff --git a/clang/lib/StaticAnalyzer/Core/Z3CrosscheckVisitor.cpp b/clang/lib/StaticAnalyzer/Core/Z3CrosscheckVisitor.cpp
new file mode 100644
index 0000000000000..a7db44ef8ea30
--- /dev/null
+++ b/clang/lib/StaticAnalyzer/Core/Z3CrosscheckVisitor.cpp
@@ -0,0 +1,118 @@
+//===- Z3CrosscheckVisitor.cpp - Crosscheck reports with Z3 -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//  This file implements the visitor and utilities around it for Z3 report
+//  refutation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/StaticAnalyzer/Core/BugReporter/Z3CrosscheckVisitor.h"
+#include "clang/StaticAnalyzer/Core/BugReporter/BugReporter.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/SMTConv.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Support/SMTAPI.h"
+
+#define DEBUG_TYPE "Z3CrosscheckOracle"
+
+STATISTIC(NumZ3QueriesDone, "Number of Z3 queries done");
+STATISTIC(NumTimesZ3TimedOut, "Number of times Z3 query timed out");
+
+STATISTIC(NumTimesZ3QueryAcceptsReport,
+          "Number of Z3 queries accepting a report");
+STATISTIC(NumTimesZ3QueryRejectReport,
+          "Number of Z3 queries rejecting a report");
+
+using namespace clang;
+using namespace ento;
+
+Z3CrosscheckVisitor::Z3CrosscheckVisitor(Z3CrosscheckVisitor::Z3Result &Result)
+    : Constraints(ConstraintMap::Factory().getEmptyMap()), Result(Result) {}
+
+void Z3CrosscheckVisitor::finalizeVisitor(BugReporterContext &BRC,
+                                          const ExplodedNode *EndPathNode,
+                                          PathSensitiveBugReport &BR) {
+  // Add the end-of-path constraints, overwriting those collected so far.
+  addConstraints(EndPathNode, /*OverwriteConstraintsOnExistingSyms=*/true);
+
+  // Create a fresh solver for this query.
+  llvm::SMTSolverRef RefutationSolver = llvm::CreateZ3Solver();
+  RefutationSolver->setBoolParam("model", true);        // Enable model finding
+  RefutationSolver->setUnsignedParam("timeout", 15000); // ms
+
+  ASTContext &Ctx = BRC.getASTContext();
+
+  // Each symbol contributes the disjunction of its ranges as a constraint.
+  for (const auto &[Sym, Range] : Constraints) {
+    auto RangeIt = Range.begin(); // NOTE(review): assumes Range is non-empty.
+
+    llvm::SMTExprRef SMTConstraints = SMTConv::getRangeExpr(
+        RefutationSolver, Ctx, Sym, RangeIt->From(), RangeIt->To(),
+        /*InRange=*/true);
+    while ((++RangeIt) != Range.end()) {
+      SMTConstraints = RefutationSolver->mkOr(
+          SMTConstraints, SMTConv::getRangeExpr(RefutationSolver, Ctx, Sym,
+                                                RangeIt->From(), RangeIt->To(),
+                                                /*InRange=*/true));
+    }
+    RefutationSolver->addConstraint(SMTConstraints);
+  }
+
+  // Check satisfiability and hand the verdict over to the caller.
+  std::optional<bool> IsSAT = RefutationSolver->check();
+  Result = Z3Result{IsSAT};
+}
+
+void Z3CrosscheckVisitor::addConstraints(
+    const ExplodedNode *N, bool OverwriteConstraintsOnExistingSyms) {
+  // Collect new constraints
+  ConstraintMap NewCs = getConstraintMap(N->getState());
+  ConstraintMap::Factory &CF = N->getState()->get_context<ConstraintMap>();
+
+  // Add constraints of new symbols; replace known ones only if requested.
+  for (auto const &[Sym, Range] : NewCs) {
+    if (!Constraints.contains(Sym)) {
+      // This symbol is new, just add the constraint.
+      Constraints = CF.add(Constraints, Sym, Range);
+    } else if (OverwriteConstraintsOnExistingSyms) {
+      // Overwrite the associated constraint of the Symbol.
+      Constraints = CF.remove(Constraints, Sym);
+      Constraints = CF.add(Constraints, Sym, Range);
+    }
+  }
+}
+
+PathDiagnosticPieceRef
+Z3CrosscheckVisitor::VisitNode(const ExplodedNode *N, BugReporterContext &,
+                               PathSensitiveBugReport &) {
+  addConstraints(N, /*OverwriteConstraintsOnExistingSyms=*/false);
+  return nullptr;
+}
+
+void Z3CrosscheckVisitor::Profile(llvm::FoldingSetNodeID &ID) const {
+  static int Tag = 0;
+  ID.AddPointer(&Tag);
+}
+
+Z3CrosscheckOracle::Z3Decision Z3CrosscheckOracle::interpretQueryResult(
+    const Z3CrosscheckVisitor::Z3Result &Query) {
+  ++NumZ3QueriesDone;
+
+  if (!Query.IsSAT.has_value()) {
+    // For backward compatibility, accept the report on a timeout.
+    ++NumTimesZ3TimedOut;
+    return AcceptReport;
+  }
+
+  if (Query.IsSAT.value()) {
+    ++NumTimesZ3QueryAcceptsReport;
+    return AcceptReport; // sat
+  }
+
+  ++NumTimesZ3QueryRejectReport;
+  return RejectReport; // unsat
+}
diff --git a/clang/test/Analysis/z3/crosscheck-statistics.c b/clang/test/Analysis/z3/crosscheck-statistics.c
new file mode 100644
index 0000000000000..7192824c5be31
--- /dev/null
+++ b/clang/test/Analysis/z3/crosscheck-statistics.c
@@ -0,0 +1,33 @@
+// RUN: %clang_analyze_cc1 -analyzer-checker=core -verify %s  \
+// RUN:   -analyzer-config crosscheck-with-z3=true \
+// RUN:   -analyzer-stats 2>&1 | FileCheck %s
+
+// REQUIRES: z3, asserts
+
+// Z3 accepts the report in 'accepting' and refutes the one in 'rejecting'.
+
+int accepting(int n) {
+  if (n == 4) {
+    n = n / (n-4); // expected-warning {{Division by zero}}
+  }
+  return n;
+}
+
+int rejecting(int n, int x) {
+  // Let's make the path infeasible.
+  if (2 < x && x < 5 && x*x == x*x*x) {
+    // Have the same condition as in 'accepting'.
+    if (n == 4) {
+      n = x / (n-4); // no-warning: refuted
+    }
+  }
+  return n;
+}
+
+// CHECK:       1 BugReporter         - Number of times all reports of an equivalence class was refuted
+// CHECK-NEXT:  1 BugReporter         - Number of reports passed Z3
+// CHECK-NEXT:  1 BugReporter         - Number of reports refuted by Z3
+
+// CHECK:       1 Z3CrosscheckOracle  - Number of Z3 queries accepting a report
+// CHECK-NEXT:  1 Z3CrosscheckOracle  - Number of Z3 queries rejecting a report
+// CHECK-NEXT:  2 Z3CrosscheckOracle  - Number of Z3 queries done
diff --git a/clang/unittests/StaticAnalyzer/CMakeLists.txt b/clang/unittests/StaticAnalyzer/CMakeLists.txt
index ff34d5747cc81..dcc557b44fb31 100644
--- a/clang/unittests/StaticAnalyzer/CMakeLists.txt
+++ b/clang/unittests/StaticAnalyzer/CMakeLists.txt
@@ -21,6 +21,7 @@ add_clang_unittest(StaticAnalysisTests
   SymbolReaperTest.cpp
   SValTest.cpp
   TestReturnValueUnderConstruction.cpp
+  Z3CrosscheckOracleTest.cpp
   )
 
 clang_target_link_libraries(StaticAnalysisTests
diff --git a/clang/unittests/StaticAnalyzer/Z3CrosscheckOracleTest.cpp b/clang/unittests/StaticAnalyzer/Z3CrosscheckOracleTest.cpp
new file mode 100644
index 0000000000000..efad4dd3f03b9
--- /dev/null
+++ b/clang/unittests/StaticAnalyzer/Z3CrosscheckOracleTest.cpp
@@ -0,0 +1,59 @@
+//===- unittests/StaticAnalyzer/Z3CrosscheckOracleTest.cpp ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/StaticAnalyzer/Core/BugReporter/Z3CrosscheckVisitor.h"
+#include "gtest/gtest.h"
+
+using namespace clang;
+using namespace ento;
+
+using Z3Result = Z3CrosscheckVisitor::Z3Result;
+using Z3Decision = Z3CrosscheckOracle::Z3Decision;
+
+static constexpr Z3Decision AcceptReport = Z3Decision::AcceptReport;
+static constexpr Z3Decision RejectReport = Z3Decision::RejectReport;
+
+static constexpr std::optional<bool> SAT = true;           // Satisfiable.
+static constexpr std::optional<bool> UNSAT = false;        // Unsatisfiable.
+static constexpr std::optional<bool> UNDEF = std::nullopt; // No verdict.
+
+namespace {
+
+struct Z3CrosscheckOracleTest : public testing::Test {
+  Z3Decision interpretQueryResult(const Z3Result &Result) const {
+    return Z3CrosscheckOracle::interpretQueryResult(Result);
+  }
+};
+
+TEST_F(Z3CrosscheckOracleTest, AcceptsFirstSAT) {
+  ASSERT_EQ(AcceptReport, interpretQueryResult({SAT}));
+}
+
+TEST_F(Z3CrosscheckOracleTest, AcceptsSAT) {
+  ASSERT_EQ(RejectReport, interpretQueryResult({UNSAT}));
+  ASSERT_EQ(AcceptReport, interpretQueryResult({SAT}));
+}
+
+TEST_F(Z3CrosscheckOracleTest, AcceptsFirstTimeout) {
+  ASSERT_EQ(AcceptReport, interpretQueryResult({UNDEF}));
+}
+
+TEST_F(Z3CrosscheckOracleTest, AcceptsTimeout) {
+  ASSERT_EQ(RejectReport, interpretQueryResult({UNSAT}));
+  ASSERT_EQ(RejectReport, interpretQueryResult({UNSAT}));
+  ASSERT_EQ(AcceptReport, interpretQueryResult({UNDEF}));
+}
+
+TEST_F(Z3CrosscheckOracleTest, RejectsUNSATs) {
+  ASSERT_EQ(RejectReport, interpretQueryResult({UNSAT}));
+  ASSERT_EQ(RejectReport, interpretQueryResult({UNSAT}));
+  ASSERT_EQ(RejectReport, interpretQueryResult({UNSAT}));
+  ASSERT_EQ(RejectReport, interpretQueryResult({UNSAT}));
+}
+
+} // namespace
diff --git a/llvm/include/llvm/Support/SMTAPI.h b/llvm/include/llvm/Support/SMTAPI.h
index 9389c96956dd1..a2a89674414f4 100644
--- a/llvm/include/llvm/Support/SMTAPI.h
+++ b/llvm/include/llvm/Support/SMTAPI.h
@@ -125,6 +125,19 @@ class SMTExpr {
   virtual bool equal_to(SMTExpr const &other) const = 0;
 };
 
+class SMTSolverStatistics {
+public:
+  SMTSolverStatistics() = default;
+  virtual ~SMTSolverStatistics() = default;
+  // Look up a statistic by name; the key is expected to be present.
+  virtual double getDouble(llvm::StringRef) const = 0;
+  virtual unsigned getUnsigned(llvm::StringRef) const = 0;
+
+  virtual void print(raw_ostream &OS) const = 0;
+
+  LLVM_DUMP_METHOD void dump() const;
+};
+
 /// Shared pointer for SMTExprs, used by SMTSolver API.
 using SMTExprRef = const SMTExpr *;
 
@@ -434,6 +447,12 @@ class SMTSolver {
   virtual bool isFPSupported() = 0;
 
   virtual void print(raw_ostream &OS) const = 0;
+
+  /// Sets the solver parameter \p Key to \p Value (e.g. "timeout").
+  virtual void setBoolParam(StringRef Key, bool Value) = 0;
+  virtual void setUnsignedParam(StringRef Key, unsigned Value) = 0;
+  /// Returns a snapshot of the solver's statistics.
+  virtual std::unique_ptr<SMTSolverStatistics> getStatistics() const = 0;
 };
 
 /// Shared pointer for SMTSolvers.
diff --git a/llvm/lib/Support/Z3Solver.cpp b/llvm/lib/Support/Z3Solver.cpp
index eb671fe2596db..5a34ff160f6cf 100644
--- a/llvm/lib/Support/Z3Solver.cpp
+++ b/llvm/lib/Support/Z3Solver.cpp
@@ -6,7 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Config/config.h"
+#include "llvm/Support/NativeFormatting.h"
 #include "llvm/Support/SMTAPI.h"
 
 using namespace llvm;
@@ -26,18 +28,14 @@ namespace {
 class Z3Config {
   friend class Z3Context;
 
-  Z3_config Config;
+  Z3_config Config = Z3_mk_config();
 
 public:
-  Z3Config() : Config(Z3_mk_config()) {
-    // Enable model finding
-    Z3_set_param_value(Config, "model", "true");
-    // Disable proof generation
-    Z3_set_param_value(Config, "proof", "false");
-    // Set timeout to 15000ms = 15s
-    Z3_set_param_value(Config, "timeout", "15000");
-  }
-
+  Z3Config() = default; // Owns a raw Z3 handle: neither copyable nor movable.
+  Z3Config(const Z3Config &) = delete;
+  Z3Config(Z3Config &&) = delete;
+  Z3Config &operator=(const Z3Config &) = delete;
+  Z3Config &operator=(Z3Config &&) = delete;
   ~Z3Config() { Z3_del_config(Config); }
 }; // end class Z3Config
 
@@ -50,16 +48,22 @@ void Z3ErrorHandler(Z3_context Context, Z3_error_code Error) {
 /// Wrapper for Z3 context
 class Z3Context {
 public:
+  Z3Config Config;
   Z3_context Context;
 
   Z3Context() {
-    Context = Z3_mk_context_rc(Z3Config().Config);
+    Context = Z3_mk_context_rc(Config.Config);
     // The error function is set here because the context is the first object
     // created by the backend
     Z3_set_error_handler(Context, Z3ErrorHandler);
   }
 
-  virtual ~Z3Context() {
+  Z3Context(const Z3Context &) = delete; // Owns raw Z3 handles; not movable.
+  Z3Context(Z3Context &&) = delete;
+  Z3Context &operator=(const Z3Context &) = delete;
+  Z3Context &operator=(Z3Context &&) = delete;
+
+  ~Z3Context() {
     Z3_del_context(Context);
     Context = nullptr;
   }
@@ -262,7 +266,17 @@ class Z3Solver : public SMTSolver {
 
   Z3Context Context;
 
-  Z3_solver Solver;
+  Z3_solver Solver = [this] {
+    Z3_solver S = Z3_mk_simple_solver(Context.Context);
+    Z3_solver_inc_ref(Context.Context, S);
+    return S;
+  }();
+
+  Z3_params Params = [this] {
+    Z3_params P = Z3_mk_params(Context.Context);
+    Z3_params_inc_ref(Context.Context, P);
+    return P;
+  }();
 
   // Cache Sorts
   std::set<Z3Sort> CachedSorts;
@@ -271,18 +285,15 @@ class Z3Solver : public SMTSolver {
   std::set<Z3Expr> CachedExprs;
 
 public:
-  Z3Solver() : Solver(Z3_mk_simple_solver(Context.Context)) {
-    Z3_solver_inc_ref(Context.Context, Solver);
-  }
-
+  Z3Solver() = default;
   Z3Solver(const Z3Solver &Other) = delete;
   Z3Solver(Z3Solver &&Other) = delete;
   Z3Solver &operator=(Z3Solver &Other) = delete;
   Z3Solver &operator=(Z3Solver &&Other) = delete;
 
-  ~Z3Solver() {
-    if (Solver)
-      Z3_solver_dec_ref(Context.Context, Solver);
+  ~Z3Solver() override {
+    Z3_params_dec_ref(Context.Context, Params);
+    Z3_solver_dec_ref(Context.Context, Solver);
   }
 
   void addConstraint(const SMTExprRef &Exp) const override {
@@ -871,6 +882,7 @@ class Z3Solver : public SMTSolver {
   }
 
   std::optional<bool> check() const override {
+    Z3_solver_set_params(Context.Context, Solver, Params); // Apply Params.
     Z3_lbool res = Z3_solver_check(Context.Context, Solver);
     if (res == Z3_L_TRUE)
       return true;
@@ -896,8 +908,71 @@ class Z3Solver : public SMTSolver {
   void print(raw_ostream &OS) const override {
     OS << Z3_solver_to_string(Context.Context, Solver);
   }
+
+  void setUnsignedParam(StringRef Key, unsigned Value) override {
+    Z3_symbol Sym = Z3_mk_string_symbol(Context.Context, Key.str().c_str());
+    Z3_params_set_uint(Context.Context, Params, Sym, Value);
+  }
+
+  void setBoolParam(StringRef Key, bool Value) override {
+    Z3_symbol Sym = Z3_mk_string_symbol(Context.Context, Key.str().c_str());
+    Z3_params_set_bool(Context.Context, Params, Sym, Value);
+  }
+
+  std::unique_ptr<SMTSolverStatistics> getStatistics() const override;
 }; // end class Z3Solver
 
+class Z3Statistics final : public SMTSolverStatistics {
+public:
+  double getDouble(StringRef Key) const override {
+    auto It = DoubleValues.find(Key.str());
+    assert(It != DoubleValues.end());
+    return It->second;
+  }
+  unsigned getUnsigned(StringRef Key) const override {
+    auto It = UnsignedValues.find(Key.str());
+    assert(It != UnsignedValues.end());
+    return It->second;
+  }
+
+  void print(raw_ostream &OS) const override { // Order is unspecified.
+    for (auto const &[K, V] : UnsignedValues) {
+      OS << K << ": " << V << '\n';
+    }
+    for (auto const &[K, V] : DoubleValues) {
+      write_double(OS << K << ": ", V, FloatStyle::Fixed);
+      OS << '\n';
+    }
+  }
+
+private:
+  friend class Z3Solver;
+  std::unordered_map<std::string, unsigned> UnsignedValues;
+  std::unordered_map<std::string, double> DoubleValues;
+};
+
+std::unique_ptr<SMTSolverStatistics> Z3Solver::getStatistics() const {
+  auto const &C = Context.Context;
+  Z3_stats S = Z3_solver_get_statistics(C, Solver);
+  Z3_stats_inc_ref(C, S);
+  auto StatsGuard = llvm::make_scope_exit([&C, &S] { Z3_stats_dec_ref(C, S); });
+  Z3Statistics Result;
+
+  unsigned NumKeys = Z3_stats_size(C, S);
+  for (unsigned Idx = 0; Idx < NumKeys; ++Idx) {
+    const char *Key = Z3_stats_get_key(C, S, Idx);
+    if (Z3_stats_is_uint(C, S, Idx)) {
+      auto Value = Z3_stats_get_uint_value(C, S, Idx);
+      Result.UnsignedValues.try_emplace(Key, Value);
+    } else {
+      assert(Z3_stats_is_double(C, S, Idx));
+      auto Value = Z3_stats_get_double_value(C, S, Idx);
+      Result.DoubleValues.try_emplace(Key, Value);
+    }
+  }
+  return std::make_unique<Z3Statistics>(std::move(Result));
+}
+
 } // end anonymous namespace
 
 #endif
@@ -916,3 +991,4 @@ llvm::SMTSolverRef llvm::CreateZ3Solver() {
 LLVM_DUMP_METHOD void SMTSort::dump() const { print(llvm::errs()); }
 LLVM_DUMP_METHOD void SMTExpr::dump() const { print(llvm::errs()); }
 LLVM_DUMP_METHOD void SMTSolver::dump() const { print(llvm::errs()); }
+LLVM_DUMP_METHOD void SMTSolverStatistics::dump() const { print(llvm::errs()); }



More information about the cfe-commits mailing list