[clang] [clang][Interp] Three-way comparisons (PR #65901)

Timm Baeder via cfe-commits cfe-commits at lists.llvm.org
Sun Sep 10 08:05:20 PDT 2023


https://github.com/tbaederr created https://github.com/llvm/llvm-project/pull/65901:

None

>From a8503c150e80df7c505eb0f82941f2fcebe8a97e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timm=20B=C3=A4der?= <tbaeder at redhat.com>
Date: Sun, 10 Sep 2023 17:02:22 +0200
Subject: [PATCH] [clang][Interp] Three-way comparisons

---
 clang/lib/AST/Interp/Boolean.h           |  6 +++
 clang/lib/AST/Interp/ByteCodeExprGen.cpp | 23 ++++++++++
 clang/lib/AST/Interp/Floating.h          |  6 +++
 clang/lib/AST/Interp/Integral.h          |  7 +++
 clang/lib/AST/Interp/Interp.h            | 29 +++++++++++++
 clang/lib/AST/Interp/InterpBuiltin.cpp   | 17 ++++++++
 clang/lib/AST/Interp/Opcodes.td          |  5 +++
 clang/lib/AST/Interp/Pointer.h           | 13 ++++++
 clang/test/AST/Interp/cxx20.cpp          | 54 ++++++++++++++++++++++++
 9 files changed, 160 insertions(+)

diff --git a/clang/lib/AST/Interp/Boolean.h b/clang/lib/AST/Interp/Boolean.h
index 6f0fe26ace68807..c3ed3d61f76ca1c 100644
--- a/clang/lib/AST/Interp/Boolean.h
+++ b/clang/lib/AST/Interp/Boolean.h
@@ -84,6 +84,12 @@ class Boolean final {
   Boolean truncate(unsigned TruncBits) const { return *this; }
 
   void print(llvm::raw_ostream &OS) const { OS << (V ? "true" : "false"); }
+
+  /// Returns the value spelled exactly as print() writes it, for use in
+  /// diagnostics.
+  std::string toDiagnosticString(const ASTContext &Ctx) const {
+    return V ? "true" : "false";
+  }
 
   static Boolean min(unsigned NumBits) { return Boolean(false); }
   static Boolean max(unsigned NumBits) { return Boolean(true); }
diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
index 6a492c4c907cde0..47e070285aebd67 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -261,6 +261,29 @@ bool ByteCodeExprGen<Emitter>::VisitBinaryOperator(const BinaryOperator *BO) {
                                 : this->visit(RHS));
   }
 
+  // C++'s three-way operator <=> yields a std::{strong,weak,partial}_ordering
+  // (a class without a PrimType); unclassifiable operands bail() below.
+  if (!T && LT && BO->getOpcode() == BO_Cmp) {
+    // The result is unused, but the operands may still have side effects.
+    if (DiscardResult)
+      return this->discard(LHS) && this->discard(RHS);
+    const ComparisonCategoryInfo *CmpInfo =
+        Ctx.getASTContext().CompCategories.lookupInfoForType(BO->getType());
+    assert(CmpInfo);
+
+    // We need a temporary variable holding our return value.
+    if (!Initializing) {
+      std::optional<unsigned> ResultIndex = this->allocateLocal(BO, false);
+      if (!ResultIndex || !this->emitGetPtrLocal(*ResultIndex, BO))
+        return false;
+    }
+
+    if (!visit(LHS) || !visit(RHS))
+      return false;
+
+    return this->emitCMP3(*LT, CmpInfo, BO);
+  }
+
   if (!LT || !RT || !T)
     return this->bail(BO);
 
diff --git a/clang/lib/AST/Interp/Floating.h b/clang/lib/AST/Interp/Floating.h
index 9a8fd34ec934893..a22b3fa79f3992f 100644
--- a/clang/lib/AST/Interp/Floating.h
+++ b/clang/lib/AST/Interp/Floating.h
@@ -76,6 +76,12 @@ class Floating final {
     F.toString(Buffer);
     OS << Buffer;
   }
+  std::string toDiagnosticString(const ASTContext &Ctx) const {
+    std::string Result; // Filled with the same text print() emits.
+    llvm::raw_string_ostream Stream(Result);
+    print(Stream);
+    return Stream.str();
+  }
 
   unsigned bitWidth() const { return F.semanticsSizeInBits(F.getSemantics()); }
 
diff --git a/clang/lib/AST/Interp/Integral.h b/clang/lib/AST/Interp/Integral.h
index 72285cabcbbf8ce..0295a9c3b5c898c 100644
--- a/clang/lib/AST/Interp/Integral.h
+++ b/clang/lib/AST/Interp/Integral.h
@@ -128,6 +128,13 @@ template <unsigned Bits, bool Signed> class Integral final {
     return Compare(V, RHS.V);
   }
 
+  std::string toDiagnosticString(const ASTContext &Ctx) const {
+    // Widen first; a char-sized ReprT would otherwise be printed as a char.
+    if constexpr (!Signed)
+      return std::to_string(static_cast<uint64_t>(V));
+    return std::to_string(static_cast<int64_t>(V));
+  }
+
   unsigned countLeadingZeros() const {
     if constexpr (!Signed)
       return llvm::countl_zero<ReprT>(V);
diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
index 5006f72fe7237f5..14f23e84386d0fd 100644
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -112,6 +112,11 @@ bool CheckCtorCall(InterpState &S, CodePtr OpPC, const Pointer &This);
 bool CheckPotentialReinterpretCast(InterpState &S, CodePtr OpPC,
                                    const Pointer &Ptr);
 
+/// Stores \p IntValue into the single integral field of the
+/// std::{weak,partial,strong}_ordering object that \p Ptr points to.
+bool SetThreeWayComparisonField(InterpState &S, CodePtr OpPC,
+                                const Pointer &Ptr, const APSInt &IntValue);
+
 /// Checks if the shift operation is legal.
 template <typename LT, typename RT>
 bool CheckShift(InterpState &S, CodePtr OpPC, const LT &LHS, const RT &RHS,
@@ -773,6 +778,30 @@ bool EQ(InterpState &S, CodePtr OpPC) {
   });
 }
 
+template <PrimType Name, class T = typename PrimConv<Name>::T>
+bool CMP3(InterpState &S, CodePtr OpPC, const ComparisonCategoryInfo *CmpInfo) {
+  assert(CmpInfo);
+  const T &RHS = S.Stk.pop<T>();
+  const T &LHS = S.Stk.pop<T>();
+  const Pointer &P = S.Stk.peek<Pointer>();
+
+  ComparisonCategoryResult CmpResult = LHS.compare(RHS);
+  // Only partial orderings (floating-point NaN) have an 'unordered' value.
+  if (CmpResult == ComparisonCategoryResult::Unordered &&
+      !CmpInfo->isPartial()) {
+    const SourceInfo &Loc = S.Current->getSource(OpPC);
+    S.FFDiag(Loc, diag::note_constexpr_pointer_comparison_unspecified)
+        << LHS.toDiagnosticString(S.getCtx())
+        << RHS.toDiagnosticString(S.getCtx());
+    return false;
+  }
+
+  const auto *CmpValueInfo =
+      CmpInfo->getValueInfo(CmpInfo->makeWeakResult(CmpResult));
+  assert(CmpValueInfo && CmpValueInfo->hasValidIntValue());
+  return SetThreeWayComparisonField(S, OpPC, P, CmpValueInfo->getIntValue());
+}
+
 template <PrimType Name, class T = typename PrimConv<Name>::T>
 bool NE(InterpState &S, CodePtr OpPC) {
   return CmpHelperEQ<T>(S, OpPC, [](ComparisonCategoryResult R) {
diff --git a/clang/lib/AST/Interp/InterpBuiltin.cpp b/clang/lib/AST/Interp/InterpBuiltin.cpp
index c607368f3b65824..f6254696844f8de 100644
--- a/clang/lib/AST/Interp/InterpBuiltin.cpp
+++ b/clang/lib/AST/Interp/InterpBuiltin.cpp
@@ -519,5 +519,22 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const Function *F,
   return false;
 }
 
+bool SetThreeWayComparisonField(InterpState &S, CodePtr OpPC,
+                                const Pointer &Ptr, const APSInt &IntValue) {
+  // The ordering class is expected to wrap exactly one integral field.
+  const Record *R = Ptr.getRecord();
+  assert(R && R->getNumFields() == 1);
+
+  const Pointer FieldPtr = Ptr.atField(R->getField(0u)->Offset);
+  std::optional<PrimType> FieldT = S.getContext().classify(FieldPtr.getType());
+  if (!FieldT)
+    return false;
+
+  INT_TYPE_SWITCH(*FieldT,
+                  FieldPtr.deref<T>() = T::from(IntValue.getSExtValue()));
+  FieldPtr.initialize();
+  return true;
+}
+
 } // namespace interp
 } // namespace clang
diff --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td
index 8bdc4432e89b410..18b78fa42244c9c 100644
--- a/clang/lib/AST/Interp/Opcodes.td
+++ b/clang/lib/AST/Interp/Opcodes.td
@@ -53,6 +53,7 @@ def ArgRoundingMode : ArgType { let Name = "llvm::RoundingMode"; }
 def ArgLETD: ArgType { let Name = "const LifetimeExtendedTemporaryDecl *"; }
 def ArgCastKind : ArgType { let Name = "CastKind"; }
 def ArgCallExpr : ArgType { let Name = "const CallExpr *"; }
+def ArgCCI : ArgType { let Name = "const ComparisonCategoryInfo *"; }
 
 //===----------------------------------------------------------------------===//
 // Classes of types instructions operate on.
@@ -599,6 +600,10 @@ class ComparisonOpcode : Opcode {
   let HasGroup = 1;
 }
 
+// Three-way comparison (operator<=>). Pops both operands and stores the result
+// into the ordering object that the pointer left on the stack refers to.
+def CMP3 : ComparisonOpcode { let Args = [ArgCCI]; }
+
 def LT : ComparisonOpcode;
 def LE : ComparisonOpcode;
 def GT : ComparisonOpcode;
diff --git a/clang/lib/AST/Interp/Pointer.h b/clang/lib/AST/Interp/Pointer.h
index 3834237f11d1314..8c97a965320e106 100644
--- a/clang/lib/AST/Interp/Pointer.h
+++ b/clang/lib/AST/Interp/Pointer.h
@@ -362,6 +362,19 @@ class Pointer {
   /// Deactivates an entire strurcutre.
   void deactivate() const;
 
+  /// Three-way compares this pointer with \p Other by their offsets.
+  /// Pointers that do not share the same base are Unordered.
+  ComparisonCategoryResult compare(const Pointer &Other) const {
+    if (!hasSameBase(*this, Other))
+      return ComparisonCategoryResult::Unordered;
+
+    if (Offset < Other.Offset)
+      return ComparisonCategoryResult::Less;
+    if (Offset > Other.Offset)
+      return ComparisonCategoryResult::Greater;
+    return ComparisonCategoryResult::Equal;
+  }
+
   /// Checks if two pointers are comparable.
   static bool hasSameBase(const Pointer &A, const Pointer &B);
   /// Checks if two pointers can be subtracted.
diff --git a/clang/test/AST/Interp/cxx20.cpp b/clang/test/AST/Interp/cxx20.cpp
index df08bb75199d86a..0b13f41270a95b8 100644
--- a/clang/test/AST/Interp/cxx20.cpp
+++ b/clang/test/AST/Interp/cxx20.cpp
@@ -646,3 +646,57 @@ namespace ImplicitFunction {
                                     // expected-error {{not an integral constant expression}} \
                                     // expected-note {{in call to 'callMe()'}}
 }
+
+/// FIXME: Unfortunately, the similar tests in test/SemaCXX/compare-cxx2a.cpp
+/// use member pointers, which we don't support yet.
+namespace std {
+  class strong_ordering {
+  public:
+    int n;
+    static const strong_ordering less, equal, greater;
+    constexpr bool operator==(int n) const noexcept { return this->n == n; }
+    constexpr bool operator!=(int n) const noexcept { return this->n != n; }
+  };
+  constexpr strong_ordering strong_ordering::less = {-1};
+  constexpr strong_ordering strong_ordering::equal = {0};
+  constexpr strong_ordering strong_ordering::greater = {1};
+
+  class partial_ordering {
+  public:
+    long n;
+    static const partial_ordering less, equal, greater, equivalent, unordered;
+    constexpr bool operator==(long n) const noexcept { return this->n == n; }
+    constexpr bool operator!=(long n) const noexcept { return this->n != n; }
+  };
+  constexpr partial_ordering partial_ordering::less = {-1};
+  constexpr partial_ordering partial_ordering::equal = {0};
+  constexpr partial_ordering partial_ordering::greater = {1};
+  constexpr partial_ordering partial_ordering::equivalent = {0};
+  constexpr partial_ordering partial_ordering::unordered = {-127};
+} // namespace std
+
+namespace ThreeWayCmp {
+  static_assert(1 <=> 2 == -1, "");
+  static_assert(1 <=> 1 == 0, "");
+  static_assert(2 <=> 1 == 1, "");
+  static_assert(1.0 <=> 2.f == -1, "");
+  static_assert(1.0 <=> 1.0 == 0, "");
+  static_assert(2.0 <=> 1.0 == 1, "");
+  constexpr int k = (1 <=> 1, 0); // expected-warning {{comparison result unused}} \
+                                  // ref-warning {{comparison result unused}}
+  static_assert(k == 0, "");
+
+  /// Pointers.
+  constexpr int a[] = {1, 2, 3};
+  constexpr int b[] = {1, 2, 3};
+  constexpr const int *pa1 = &a[1];
+  constexpr const int *pa2 = &a[2];
+  constexpr const int *pb1 = &b[1];
+  static_assert(pa1 <=> pb1 != 0, ""); // expected-error {{not an integral constant expression}} \
+                                       // expected-note {{has unspecified value}} \
+                                       // ref-error {{not an integral constant expression}} \
+                                       // ref-note {{has unspecified value}}
+  static_assert(pa1 <=> pa1 == 0, "");
+  static_assert(pa1 <=> pa2 == -1, "");
+  static_assert(pa2 <=> pa1 == 1, "");
+} // namespace ThreeWayCmp



More information about the cfe-commits mailing list