[clang] [clang][Interp] Three-way comparisons (PR #65901)
Timm Baeder via cfe-commits
cfe-commits at lists.llvm.org
Tue Sep 26 21:53:07 PDT 2023
https://github.com/tbaederr updated https://github.com/llvm/llvm-project/pull/65901
>From 5dc457a45bdf8366a5f8e5d7df7bcf3383120ecd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timm=20B=C3=A4der?= <tbaeder at redhat.com>
Date: Sun, 10 Sep 2023 17:02:22 +0200
Subject: [PATCH] [clang][Interp] Three-way comparisons
---
clang/lib/AST/Interp/Boolean.h | 6 +++
clang/lib/AST/Interp/ByteCodeExprGen.cpp | 23 ++++++++++
clang/lib/AST/Interp/Floating.h | 6 +++
clang/lib/AST/Interp/Integral.h | 7 +++
clang/lib/AST/Interp/Interp.h | 29 +++++++++++++
clang/lib/AST/Interp/InterpBuiltin.cpp | 17 ++++++++
clang/lib/AST/Interp/Opcodes.td | 5 +++
clang/lib/AST/Interp/Pointer.h | 13 ++++++
clang/test/AST/Interp/cxx20.cpp | 54 ++++++++++++++++++++++++
9 files changed, 160 insertions(+)
diff --git a/clang/lib/AST/Interp/Boolean.h b/clang/lib/AST/Interp/Boolean.h
index 6f0fe26ace68807..c3ed3d61f76ca1c 100644
--- a/clang/lib/AST/Interp/Boolean.h
+++ b/clang/lib/AST/Interp/Boolean.h
@@ -84,6 +84,12 @@ class Boolean final {
Boolean truncate(unsigned TruncBits) const { return *this; }
void print(llvm::raw_ostream &OS) const { OS << (V ? "true" : "false"); }
+  std::string toDiagnosticString(const ASTContext &Ctx) const {
+    std::string Text; // Same spelling as print(): "true" or "false".
+    llvm::raw_string_ostream Stream(Text);
+    print(Stream);
+    return Stream.str();
+  }
static Boolean min(unsigned NumBits) { return Boolean(false); }
static Boolean max(unsigned NumBits) { return Boolean(true); }
diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
index e813d4fa651ceaf..a09e2a007b912c9 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -253,6 +253,38 @@ bool ByteCodeExprGen<Emitter>::VisitBinaryOperator(const BinaryOperator *BO) {
return this->delegate(RHS);
}
+  // Special case for C++'s three-way/spaceship operator <=>, which
+  // returns a std::{strong,weak,partial}_ordering (which is a class, so doesn't
+  // have a PrimType). Dispatch on the opcode rather than on `!T`: other binary
+  // operators can have a non-primitive result type as well (e.g. struct
+  // assignment in C), and those must not be treated as comparisons.
+  if (BO->getOpcode() == BO_Cmp) {
+    // The operand types must be classifiable; bail out on anything we do not
+    // support instead of dereferencing an empty optional below.
+    if (!LT || !RT)
+      return this->bail(BO);
+
+    // Even when the result is unused, the operands may have side effects.
+    if (DiscardResult)
+      return this->discard(LHS) && this->discard(RHS);
+
+    const ComparisonCategoryInfo *CmpInfo =
+        Ctx.getASTContext().CompCategories.lookupInfoForType(BO->getType());
+    assert(CmpInfo);
+
+    // We need a temporary variable holding our return value.
+    if (!Initializing) {
+      std::optional<unsigned> ResultIndex = this->allocateLocal(BO, false);
+      if (!ResultIndex || !this->emitGetPtrLocal(*ResultIndex, BO))
+        return false;
+    }
+
+    if (!visit(LHS) || !visit(RHS))
+      return false;
+
+    return this->emitCMP3(*LT, CmpInfo, BO);
+  }
+
if (!LT || !RT || !T)
return this->bail(BO);
diff --git a/clang/lib/AST/Interp/Floating.h b/clang/lib/AST/Interp/Floating.h
index 9a8fd34ec934893..a22b3fa79f3992f 100644
--- a/clang/lib/AST/Interp/Floating.h
+++ b/clang/lib/AST/Interp/Floating.h
@@ -76,6 +76,12 @@ class Floating final {
F.toString(Buffer);
OS << Buffer;
}
+  std::string toDiagnosticString(const ASTContext &Ctx) const {
+    std::string Text; // Diagnostics render the value exactly like print().
+    llvm::raw_string_ostream Stream(Text);
+    print(Stream);
+    return Stream.str();
+  }
unsigned bitWidth() const { return F.semanticsSizeInBits(F.getSemantics()); }
diff --git a/clang/lib/AST/Interp/Integral.h b/clang/lib/AST/Interp/Integral.h
index 72285cabcbbf8ce..0295a9c3b5c898c 100644
--- a/clang/lib/AST/Interp/Integral.h
+++ b/clang/lib/AST/Interp/Integral.h
@@ -128,6 +128,13 @@ template <unsigned Bits, bool Signed> class Integral final {
return Compare(V, RHS.V);
}
+  std::string toDiagnosticString(const ASTContext &Ctx) const {
+    // Not `raw_ostream << V`: for 8-bit Integrals V is an int8_t/uint8_t, and
+    // raw_ostream's char overloads would emit it as a character, not a number.
+    // std::to_string promotes such values to int and prints them in decimal.
+    return std::to_string(V);
+  }
+
unsigned countLeadingZeros() const {
if constexpr (!Signed)
return llvm::countl_zero<ReprT>(V);
diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
index 8453856e526a6b2..dd37150b63f6db0 100644
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -112,6 +112,11 @@ bool CheckCtorCall(InterpState &S, CodePtr OpPC, const Pointer &This);
bool CheckPotentialReinterpretCast(InterpState &S, CodePtr OpPC,
const Pointer &Ptr);
+/// Stores \p IntValue into the single field of the std::{weak,partial,strong}_ordering
+/// object \p Ptr points to and marks that field as initialized.
+bool SetThreeWayComparisonField(InterpState &S, CodePtr OpPC,
+ const Pointer &Ptr, const APSInt &IntValue);
+
/// Checks if the shift operation is legal.
template <typename LT, typename RT>
bool CheckShift(InterpState &S, CodePtr OpPC, const LT &LHS, const RT &RHS,
@@ -781,6 +786,39 @@ bool EQ(InterpState &S, CodePtr OpPC) {
});
}
+/// Pops two values of type T, three-way compares them and stores the outcome
+/// into the comparison-category object referenced by the pointer below them
+/// on the stack. That pointer is left on the stack.
+template <PrimType Name, class T = typename PrimConv<Name>::T>
+bool CMP3(InterpState &S, CodePtr OpPC, const ComparisonCategoryInfo *CmpInfo) {
+  assert(CmpInfo);
+  const T &RHS = S.Stk.pop<T>();
+  const T &LHS = S.Stk.pop<T>();
+  const Pointer &P = S.Stk.peek<Pointer>();
+
+  ComparisonCategoryResult CmpResult = LHS.compare(RHS);
+  // Unordered is a valid std::partial_ordering result (floating-point
+  // operands involving NaN). For the other categories it is expected only
+  // from pointers without a common base, whose order is unspecified.
+  if (CmpResult == ComparisonCategoryResult::Unordered &&
+      !CmpInfo->isPartial()) {
+    const SourceInfo &Loc = S.Current->getSource(OpPC);
+    S.FFDiag(Loc, diag::note_constexpr_pointer_comparison_unspecified)
+        << LHS.toDiagnosticString(S.getCtx())
+        << RHS.toDiagnosticString(S.getCtx());
+    return false;
+  }
+
+  // weak_ordering and partial_ordering have an 'equivalent' value but no
+  // 'equal' one, so lower Equal to what this category actually provides.
+  const auto *CmpValueInfo =
+      CmpInfo->getValueInfo(CmpInfo->makeWeakResult(CmpResult));
+  assert(CmpValueInfo);
+  assert(CmpValueInfo->hasValidIntValue());
+  APSInt IntValue = CmpValueInfo->getIntValue();
+  return SetThreeWayComparisonField(S, OpPC, P, IntValue);
+}
+
template <PrimType Name, class T = typename PrimConv<Name>::T>
bool NE(InterpState &S, CodePtr OpPC) {
return CmpHelperEQ<T>(S, OpPC, [](ComparisonCategoryResult R) {
diff --git a/clang/lib/AST/Interp/InterpBuiltin.cpp b/clang/lib/AST/Interp/InterpBuiltin.cpp
index 4536e335bf1a162..d816145598049b0 100644
--- a/clang/lib/AST/Interp/InterpBuiltin.cpp
+++ b/clang/lib/AST/Interp/InterpBuiltin.cpp
@@ -594,5 +594,22 @@ bool InterpretOffsetOf(InterpState &S, CodePtr OpPC, const OffsetOfExpr *E,
return true;
}
+/// Writes \p IntValue into the only field of the record \p Ptr points to
+/// (the value member of a std::*_ordering object) and marks it initialized.
+bool SetThreeWayComparisonField(InterpState &S, CodePtr OpPC,
+                                const Pointer &Ptr, const APSInt &IntValue) {
+  const Record *OrderingRecord = Ptr.getRecord();
+  assert(OrderingRecord);
+  assert(OrderingRecord->getNumFields() == 1);
+
+  const Pointer &ValuePtr = Ptr.atField(OrderingRecord->getField(0u)->Offset);
+  const PrimType ValueT = *S.getContext().classify(ValuePtr.getType());
+  const int64_t RawValue = IntValue.getSExtValue();
+
+  INT_TYPE_SWITCH(ValueT, ValuePtr.deref<T>() = T::from(RawValue));
+  ValuePtr.initialize();
+  return true;
+}
+
} // namespace interp
} // namespace clang
diff --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td
index eeb71db125fef73..0ce64b769b01fb7 100644
--- a/clang/lib/AST/Interp/Opcodes.td
+++ b/clang/lib/AST/Interp/Opcodes.td
@@ -55,6 +55,7 @@ def ArgCastKind : ArgType { let Name = "CastKind"; }
def ArgCallExpr : ArgType { let Name = "const CallExpr *"; }
def ArgOffsetOfExpr : ArgType { let Name = "const OffsetOfExpr *"; }
def ArgDeclRef : ArgType { let Name = "const DeclRefExpr *"; }
+def ArgCCI : ArgType { let Name = "const ComparisonCategoryInfo *"; } // Used by CMP3.
//===----------------------------------------------------------------------===//
// Classes of types instructions operate on.
@@ -607,6 +608,10 @@ class ComparisonOpcode : Opcode {
let HasGroup = 1;
}
+def CMP3 : ComparisonOpcode {
+ let Args = [ArgCCI]; // Comparison category (strong/weak/partial ordering) of the <=> result.
+}
+
def LT : ComparisonOpcode;
def LE : ComparisonOpcode;
def GT : ComparisonOpcode;
diff --git a/clang/lib/AST/Interp/Pointer.h b/clang/lib/AST/Interp/Pointer.h
index 3834237f11d1314..8c97a965320e106 100644
--- a/clang/lib/AST/Interp/Pointer.h
+++ b/clang/lib/AST/Interp/Pointer.h
@@ -362,6 +362,19 @@ class Pointer {
/// Deactivates an entire strurcutre.
void deactivate() const;
+  /// Three-way compares this pointer with \p Other. Pointers that do not
+  /// share a base are unordered; otherwise the result orders their offsets
+  /// within that common base.
+  ComparisonCategoryResult compare(const Pointer &Other) const {
+    if (!hasSameBase(*this, Other))
+      return ComparisonCategoryResult::Unordered;
+
+    if (Offset == Other.Offset)
+      return ComparisonCategoryResult::Equal;
+    return Offset < Other.Offset ? ComparisonCategoryResult::Less
+                                 : ComparisonCategoryResult::Greater;
+  }
+
/// Checks if two pointers are comparable.
static bool hasSameBase(const Pointer &A, const Pointer &B);
/// Checks if two pointers can be subtracted.
diff --git a/clang/test/AST/Interp/cxx20.cpp b/clang/test/AST/Interp/cxx20.cpp
index df08bb75199d86a..0b13f41270a95b8 100644
--- a/clang/test/AST/Interp/cxx20.cpp
+++ b/clang/test/AST/Interp/cxx20.cpp
@@ -646,3 +646,57 @@ namespace ImplicitFunction {
// expected-error {{not an integral constant expression}} \
// expected-note {{in call to 'callMe()'}}
}
+
+/// FIXME: The similar tests in test/SemaCXX/compare-cxx2a.cpp use member pointers, which
+/// we don't support yet, so minimal stand-ins for the std ordering types are defined here.
+namespace std {
+ class strong_ordering {
+ public:
+ int n;
+ static const strong_ordering less, equal, greater;
+ constexpr bool operator==(int n) const noexcept { return this->n == n;}
+ constexpr bool operator!=(int n) const noexcept { return this->n != n;}
+ };
+ constexpr strong_ordering strong_ordering::less = {-1};
+ constexpr strong_ordering strong_ordering::equal = {0};
+ constexpr strong_ordering strong_ordering::greater = {1};
+
+ class partial_ordering {
+ public:
+ long n;
+ static const partial_ordering less, equal, greater, equivalent, unordered; // NOTE(review): the real std::partial_ordering has no 'equal'; keeping it may hide Equal/Equivalent mix-ups.
+ constexpr bool operator==(long n) const noexcept { return this->n == n;}
+ constexpr bool operator!=(long n) const noexcept { return this->n != n;}
+ };
+ constexpr partial_ordering partial_ordering::less = {-1};
+ constexpr partial_ordering partial_ordering::equal = {0};
+ constexpr partial_ordering partial_ordering::greater = {1};
+ constexpr partial_ordering partial_ordering::equivalent = {0};
+ constexpr partial_ordering partial_ordering::unordered = {-127};
+} // namespace std
+
+namespace ThreeWayCmp { // Built-in <=> on integral, floating-point and pointer operands.
+ static_assert(1 <=> 2 == -1, "");
+ static_assert(1 <=> 1 == 0, "");
+ static_assert(2 <=> 1 == 1, "");
+ static_assert(1.0 <=> 2.f == -1, ""); // Mixed double/float operands.
+ static_assert(1.0 <=> 1.0 == 0, "");
+ static_assert(2.0 <=> 1.0 == 1, "");
+ constexpr int k = (1 <=> 1, 0); // expected-warning {{comparison result unused}} \
+ // ref-warning {{comparison result unused}}
+ static_assert(k== 0, "");
+
+ /// Pointers: ordered within one array, unspecified between different arrays.
+ constexpr int a[] = {1,2,3};
+ constexpr int b[] = {1,2,3};
+ constexpr const int *pa1 = &a[1];
+ constexpr const int *pa2 = &a[2];
+ constexpr const int *pb1 = &b[1];
+ static_assert(pa1 <=> pb1 != 0, ""); // expected-error {{not an integral constant expression}} \
+ // expected-note {{has unspecified value}} \
+ // ref-error {{not an integral constant expression}} \
+ // ref-note {{has unspecified value}}
+ static_assert(pa1 <=> pa1 == 0, "");
+ static_assert(pa1 <=> pa2 == -1, "");
+ static_assert(pa2 <=> pa1 == 1, "");
+}
More information about the cfe-commits
mailing list