[clang] 1542601 - [clang][Interp] Handle non-complex operands in complex bin ops

Timm Bäder via cfe-commits cfe-commits at lists.llvm.org
Mon Feb 26 00:09:06 PST 2024


Author: Timm Bäder
Date: 2024-02-26T09:08:50+01:00
New Revision: 15426017bda54fb8d9a62cb887edae754e8b7733

URL: https://github.com/llvm/llvm-project/commit/15426017bda54fb8d9a62cb887edae754e8b7733
DIFF: https://github.com/llvm/llvm-project/commit/15426017bda54fb8d9a62cb887edae754e8b7733.diff

LOG: [clang][Interp] Handle non-complex operands in complex bin ops

Either LHS or RHS might be non-complex, but not both.

Added: 
    

Modified: 
    clang/lib/AST/Interp/ByteCodeExprGen.cpp
    clang/test/AST/Interp/complex.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
index eb5a1b536b7798..49ba8e95f17995 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -660,19 +660,16 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
       return false;
   }
 
+  // Both LHS and RHS might _not_ be of complex type, but one of them
+  // needs to be.
   const Expr *LHS = E->getLHS();
   const Expr *RHS = E->getRHS();
-  PrimType LHSElemT = this->classifyComplexElementType(LHS->getType());
-  PrimType RHSElemT = this->classifyComplexElementType(RHS->getType());
 
-  unsigned LHSOffset = this->allocateLocalPrimitive(LHS, PT_Ptr, true, false);
-  unsigned RHSOffset = this->allocateLocalPrimitive(RHS, PT_Ptr, true, false);
+  PrimType ResultElemT = this->classifyComplexElementType(E->getType());
   unsigned ResultOffset = ~0u;
-  if (!this->DiscardResult)
+  if (!DiscardResult)
     ResultOffset = this->allocateLocalPrimitive(E, PT_Ptr, true, false);
 
-  assert(LHSElemT == RHSElemT);
-
   // Save result pointer in ResultOffset
   if (!this->DiscardResult) {
     if (!this->emitDupPtr(E))
@@ -682,16 +679,64 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
   }
 
   // Evaluate LHS and save value to LHSOffset.
-  if (!this->visit(LHS))
-    return false;
-  if (!this->emitSetLocal(PT_Ptr, LHSOffset, E))
-    return false;
+  bool LHSIsComplex;
+  unsigned LHSOffset;
+  if (LHS->getType()->isAnyComplexType()) {
+    LHSIsComplex = true;
+    LHSOffset = this->allocateLocalPrimitive(LHS, PT_Ptr, true, false);
+    if (!this->visit(LHS))
+      return false;
+    if (!this->emitSetLocal(PT_Ptr, LHSOffset, E))
+      return false;
+  } else {
+    LHSIsComplex = false;
+    PrimType LHST = classifyPrim(LHS->getType());
+    LHSOffset = this->allocateLocalPrimitive(LHS, LHST, true, false);
+    if (!this->visit(LHS))
+      return false;
+    if (!this->emitSetLocal(LHST, LHSOffset, E))
+      return false;
+  }
 
   // Same with RHS.
-  if (!this->visit(RHS))
-    return false;
-  if (!this->emitSetLocal(PT_Ptr, RHSOffset, E))
-    return false;
+  bool RHSIsComplex;
+  unsigned RHSOffset;
+  if (RHS->getType()->isAnyComplexType()) {
+    RHSIsComplex = true;
+    RHSOffset = this->allocateLocalPrimitive(RHS, PT_Ptr, true, false);
+    if (!this->visit(RHS))
+      return false;
+    if (!this->emitSetLocal(PT_Ptr, RHSOffset, E))
+      return false;
+  } else {
+    RHSIsComplex = false;
+    PrimType RHST = classifyPrim(RHS->getType());
+    RHSOffset = this->allocateLocalPrimitive(RHS, RHST, true, false);
+    if (!this->visit(RHS))
+      return false;
+    if (!this->emitSetLocal(RHST, RHSOffset, E))
+      return false;
+  }
+
+  // For both LHS and RHS, either load the value from the complex pointer, or
+  // directly from the local variable. For index 1 (i.e. the imaginary part),
+  // just load 0 and do the operation anyway.
+  auto loadComplexValue = [this](bool IsComplex, unsigned ElemIndex,
+                                 unsigned Offset, const Expr *E) -> bool {
+    if (IsComplex) {
+      if (!this->emitGetLocal(PT_Ptr, Offset, E))
+        return false;
+      if (!this->emitConstUint8(ElemIndex, E))
+        return false;
+      if (!this->emitArrayElemPtrPopUint8(E))
+        return false;
+      return this->emitLoadPop(classifyComplexElementType(E->getType()), E);
+    }
+    if (ElemIndex == 0)
+      return this->emitGetLocal(classifyPrim(E->getType()), Offset, E);
+    return this->visitZeroInitializer(classifyPrim(E->getType()), E->getType(),
+                                      E);
+  };
 
   // Now we can get pointers to the LHS and RHS from the offsets above.
   BinaryOperatorKind Op = E->getOpcode();
@@ -702,41 +747,29 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
         return false;
     }
 
-    if (!this->emitGetLocal(PT_Ptr, LHSOffset, E))
-      return false;
-    if (!this->emitConstUint8(ElemIndex, E))
-      return false;
-    if (!this->emitArrayElemPtrPopUint8(E))
-      return false;
-    if (!this->emitLoadPop(LHSElemT, E))
+    if (!loadComplexValue(LHSIsComplex, ElemIndex, LHSOffset, LHS))
       return false;
 
-    if (!this->emitGetLocal(PT_Ptr, RHSOffset, E))
-      return false;
-    if (!this->emitConstUint8(ElemIndex, E))
-      return false;
-    if (!this->emitArrayElemPtrPopUint8(E))
-      return false;
-    if (!this->emitLoadPop(RHSElemT, E))
+    if (!loadComplexValue(RHSIsComplex, ElemIndex, RHSOffset, RHS))
       return false;
 
     // The actual operation.
     switch (Op) {
     case BO_Add:
-      if (LHSElemT == PT_Float) {
+      if (ResultElemT == PT_Float) {
         if (!this->emitAddf(getRoundingMode(E), E))
           return false;
       } else {
-        if (!this->emitAdd(LHSElemT, E))
+        if (!this->emitAdd(ResultElemT, E))
           return false;
       }
       break;
     case BO_Sub:
-      if (LHSElemT == PT_Float) {
+      if (ResultElemT == PT_Float) {
         if (!this->emitSubf(getRoundingMode(E), E))
           return false;
       } else {
-        if (!this->emitSub(LHSElemT, E))
+        if (!this->emitSub(ResultElemT, E))
           return false;
       }
       break;
@@ -747,10 +780,10 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
 
     if (!this->DiscardResult) {
       // Initialize array element with the value we just computed.
-      if (!this->emitInitElemPop(LHSElemT, ElemIndex, E))
+      if (!this->emitInitElemPop(ResultElemT, ElemIndex, E))
         return false;
     } else {
-      if (!this->emitPop(LHSElemT, E))
+      if (!this->emitPop(ResultElemT, E))
         return false;
     }
   }

diff  --git a/clang/test/AST/Interp/complex.cpp b/clang/test/AST/Interp/complex.cpp
index 612a94b31171b9..2b65ccf9946e70 100644
--- a/clang/test/AST/Interp/complex.cpp
+++ b/clang/test/AST/Interp/complex.cpp
@@ -164,6 +164,20 @@ namespace Add {
   constexpr _Complex unsigned int I3 = I1 + I2;
   static_assert(__real(I3) == 45, "");
   static_assert(__imag(I3) == 12, "");
+
+  static_assert(__real(A + 2.0) == 15, "");
+  static_assert(__imag(A + 2.0) == 2, "");
+  static_assert(__real(2.0 + A) == 15, "");
+  static_assert(__imag(2.0 + A) == 2, "");
+
+  static_assert(__real(D + 1) == 16, "");
+  static_assert(__real(D + 1.0) == 16, "");
+  constexpr _Complex double D2 = D + 3.0;
+  static_assert(__real(D2) == 18.0, "");
+  static_assert(__imag(D2) == 3.0, "");
+  constexpr _Complex double D3 = 3.0 + D;
+  static_assert(__real(D3) == 18.0, "");
+  static_assert(__imag(D3) == 3.0, "");
 }
 
 namespace Sub {
@@ -172,6 +186,8 @@ namespace Sub {
   constexpr _Complex float C = A - B;
   static_assert(__real(C) == 11.0, "");
   static_assert(__imag(C) == 1.0, "");
+  static_assert(__real(A - 2.0) == 11, "");
+  static_assert(__real(2.0 - A) == -11, "");
 
   constexpr _Complex float D = B - A;
   static_assert(__real(D) == -11.0, "");
@@ -189,6 +205,15 @@ namespace Sub {
   constexpr _Complex float D_ = A_ - B_;
   static_assert(__real(D_) == 11.0, "");
   static_assert(__imag(D_) == 1.0, "");
+
+  static_assert(__real(D - 1) == -12, "");
+  static_assert(__real(D - 1.0) == -12, "");
+  constexpr _Complex double D2 = D - 3.0;
+  static_assert(__real(D2) == -14.0, "");
+  static_assert(__imag(D2) == -1.0, "");
+  constexpr _Complex double D3 = 3.0 - D;
+  static_assert(__real(D3) == 14.0, "");
+  static_assert(__imag(D3) == 1.0, "");
 }
 
 }


        


More information about the cfe-commits mailing list