[clang] 1542601 - [clang][Interp] Handle non-complex operands in complex bin ops
Timm Bäder via cfe-commits
cfe-commits at lists.llvm.org
Mon Feb 26 00:09:06 PST 2024
Author: Timm Bäder
Date: 2024-02-26T09:08:50+01:00
New Revision: 15426017bda54fb8d9a62cb887edae754e8b7733
URL: https://github.com/llvm/llvm-project/commit/15426017bda54fb8d9a62cb887edae754e8b7733
DIFF: https://github.com/llvm/llvm-project/commit/15426017bda54fb8d9a62cb887edae754e8b7733.diff
LOG: [clang][Interp] Handle non-complex operands in complex bin ops
Either LHS or RHS might be non-complex, but not both.
Added:
Modified:
clang/lib/AST/Interp/ByteCodeExprGen.cpp
clang/test/AST/Interp/complex.cpp
Removed:
################################################################################
diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
index eb5a1b536b7798..49ba8e95f17995 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -660,19 +660,16 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
return false;
}
+ // Both LHS and RHS might _not_ be of complex type, but one of them
+ // needs to be.
const Expr *LHS = E->getLHS();
const Expr *RHS = E->getRHS();
- PrimType LHSElemT = this->classifyComplexElementType(LHS->getType());
- PrimType RHSElemT = this->classifyComplexElementType(RHS->getType());
- unsigned LHSOffset = this->allocateLocalPrimitive(LHS, PT_Ptr, true, false);
- unsigned RHSOffset = this->allocateLocalPrimitive(RHS, PT_Ptr, true, false);
+ PrimType ResultElemT = this->classifyComplexElementType(E->getType());
unsigned ResultOffset = ~0u;
- if (!this->DiscardResult)
+ if (!DiscardResult)
ResultOffset = this->allocateLocalPrimitive(E, PT_Ptr, true, false);
- assert(LHSElemT == RHSElemT);
-
// Save result pointer in ResultOffset
if (!this->DiscardResult) {
if (!this->emitDupPtr(E))
@@ -682,16 +679,64 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
}
// Evaluate LHS and save value to LHSOffset.
- if (!this->visit(LHS))
- return false;
- if (!this->emitSetLocal(PT_Ptr, LHSOffset, E))
- return false;
+ bool LHSIsComplex;
+ unsigned LHSOffset;
+ if (LHS->getType()->isAnyComplexType()) {
+ LHSIsComplex = true;
+ LHSOffset = this->allocateLocalPrimitive(LHS, PT_Ptr, true, false);
+ if (!this->visit(LHS))
+ return false;
+ if (!this->emitSetLocal(PT_Ptr, LHSOffset, E))
+ return false;
+ } else {
+ LHSIsComplex = false;
+ PrimType LHST = classifyPrim(LHS->getType());
+ LHSOffset = this->allocateLocalPrimitive(LHS, LHST, true, false);
+ if (!this->visit(LHS))
+ return false;
+ if (!this->emitSetLocal(LHST, LHSOffset, E))
+ return false;
+ }
// Same with RHS.
- if (!this->visit(RHS))
- return false;
- if (!this->emitSetLocal(PT_Ptr, RHSOffset, E))
- return false;
+ bool RHSIsComplex;
+ unsigned RHSOffset;
+ if (RHS->getType()->isAnyComplexType()) {
+ RHSIsComplex = true;
+ RHSOffset = this->allocateLocalPrimitive(RHS, PT_Ptr, true, false);
+ if (!this->visit(RHS))
+ return false;
+ if (!this->emitSetLocal(PT_Ptr, RHSOffset, E))
+ return false;
+ } else {
+ RHSIsComplex = false;
+ PrimType RHST = classifyPrim(RHS->getType());
+ RHSOffset = this->allocateLocalPrimitive(RHS, RHST, true, false);
+ if (!this->visit(RHS))
+ return false;
+ if (!this->emitSetLocal(RHST, RHSOffset, E))
+ return false;
+ }
+
+ // For both LHS and RHS, either load the value from the complex pointer, or
+ // directly from the local variable. For index 1 (i.e. the imaginary part),
+ // just load 0 and do the operation anyway.
+ auto loadComplexValue = [this](bool IsComplex, unsigned ElemIndex,
+ unsigned Offset, const Expr *E) -> bool {
+ if (IsComplex) {
+ if (!this->emitGetLocal(PT_Ptr, Offset, E))
+ return false;
+ if (!this->emitConstUint8(ElemIndex, E))
+ return false;
+ if (!this->emitArrayElemPtrPopUint8(E))
+ return false;
+ return this->emitLoadPop(classifyComplexElementType(E->getType()), E);
+ }
+ if (ElemIndex == 0)
+ return this->emitGetLocal(classifyPrim(E->getType()), Offset, E);
+ return this->visitZeroInitializer(classifyPrim(E->getType()), E->getType(),
+ E);
+ };
// Now we can get pointers to the LHS and RHS from the offsets above.
BinaryOperatorKind Op = E->getOpcode();
@@ -702,41 +747,29 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
return false;
}
- if (!this->emitGetLocal(PT_Ptr, LHSOffset, E))
- return false;
- if (!this->emitConstUint8(ElemIndex, E))
- return false;
- if (!this->emitArrayElemPtrPopUint8(E))
- return false;
- if (!this->emitLoadPop(LHSElemT, E))
+ if (!loadComplexValue(LHSIsComplex, ElemIndex, LHSOffset, LHS))
return false;
- if (!this->emitGetLocal(PT_Ptr, RHSOffset, E))
- return false;
- if (!this->emitConstUint8(ElemIndex, E))
- return false;
- if (!this->emitArrayElemPtrPopUint8(E))
- return false;
- if (!this->emitLoadPop(RHSElemT, E))
+ if (!loadComplexValue(RHSIsComplex, ElemIndex, RHSOffset, RHS))
return false;
// The actual operation.
switch (Op) {
case BO_Add:
- if (LHSElemT == PT_Float) {
+ if (ResultElemT == PT_Float) {
if (!this->emitAddf(getRoundingMode(E), E))
return false;
} else {
- if (!this->emitAdd(LHSElemT, E))
+ if (!this->emitAdd(ResultElemT, E))
return false;
}
break;
case BO_Sub:
- if (LHSElemT == PT_Float) {
+ if (ResultElemT == PT_Float) {
if (!this->emitSubf(getRoundingMode(E), E))
return false;
} else {
- if (!this->emitSub(LHSElemT, E))
+ if (!this->emitSub(ResultElemT, E))
return false;
}
break;
@@ -747,10 +780,10 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
if (!this->DiscardResult) {
// Initialize array element with the value we just computed.
- if (!this->emitInitElemPop(LHSElemT, ElemIndex, E))
+ if (!this->emitInitElemPop(ResultElemT, ElemIndex, E))
return false;
} else {
- if (!this->emitPop(LHSElemT, E))
+ if (!this->emitPop(ResultElemT, E))
return false;
}
}
diff --git a/clang/test/AST/Interp/complex.cpp b/clang/test/AST/Interp/complex.cpp
index 612a94b31171b9..2b65ccf9946e70 100644
--- a/clang/test/AST/Interp/complex.cpp
+++ b/clang/test/AST/Interp/complex.cpp
@@ -164,6 +164,20 @@ namespace Add {
constexpr _Complex unsigned int I3 = I1 + I2;
static_assert(__real(I3) == 45, "");
static_assert(__imag(I3) == 12, "");
+
+ static_assert(__real(A + 2.0) == 15, "");
+ static_assert(__imag(A + 2.0) == 2, "");
+ static_assert(__real(2.0 + A) == 15, "");
+ static_assert(__imag(2.0 + A) == 2, "");
+
+ static_assert(__real(D + 1) == 16, "");
+ static_assert(__real(D + 1.0) == 16, "");
+ constexpr _Complex double D2 = D + 3.0;
+ static_assert(__real(D2) == 18.0, "");
+ static_assert(__imag(D2) == 3.0, "");
+ constexpr _Complex double D3 = 3.0 + D;
+ static_assert(__real(D3) == 18.0, "");
+ static_assert(__imag(D3) == 3.0, "");
}
namespace Sub {
@@ -172,6 +186,8 @@ namespace Sub {
constexpr _Complex float C = A - B;
static_assert(__real(C) == 11.0, "");
static_assert(__imag(C) == 1.0, "");
+ static_assert(__real(A - 2.0) == 11, "");
+ static_assert(__real(2.0 - A) == -11, "");
constexpr _Complex float D = B - A;
static_assert(__real(D) == -11.0, "");
@@ -189,6 +205,15 @@ namespace Sub {
constexpr _Complex float D_ = A_ - B_;
static_assert(__real(D_) == 11.0, "");
static_assert(__imag(D_) == 1.0, "");
+
+ static_assert(__real(D - 1) == -12, "");
+ static_assert(__real(D - 1.0) == -12, "");
+ constexpr _Complex double D2 = D - 3.0;
+ static_assert(__real(D2) == -14.0, "");
+ static_assert(__imag(D2) == -1.0, "");
+ constexpr _Complex double D3 = 3.0 - D;
+ static_assert(__real(D3) == 14.0, "");
+ static_assert(__imag(D3) == 1.0, "");
}
}
More information about the cfe-commits mailing list