[clang] 4d7d45e - [clang][Interp] Implement complex division (#94892)

via cfe-commits cfe-commits at lists.llvm.org
Tue Jun 18 04:49:07 PDT 2024


Author: Timm Baeder
Date: 2024-06-18T13:49:02+02:00
New Revision: 4d7d45e8bab102f8792d45e1807cfe2d1c55dd69

URL: https://github.com/llvm/llvm-project/commit/4d7d45e8bab102f8792d45e1807cfe2d1c55dd69
DIFF: https://github.com/llvm/llvm-project/commit/4d7d45e8bab102f8792d45e1807cfe2d1c55dd69.diff

LOG: [clang][Interp] Implement complex division (#94892)

Share the implementation with the current interpreter.

Added: 
    

Modified: 
    clang/lib/AST/ExprConstShared.h
    clang/lib/AST/ExprConstant.cpp
    clang/lib/AST/Interp/ByteCodeExprGen.cpp
    clang/lib/AST/Interp/Interp.h
    clang/lib/AST/Interp/Opcodes.td
    clang/test/AST/Interp/complex.cpp
    clang/test/SemaCXX/complex-folding.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/AST/ExprConstShared.h b/clang/lib/AST/ExprConstShared.h
index 9decd47e41767..2a7088e4e371a 100644
--- a/clang/lib/AST/ExprConstShared.h
+++ b/clang/lib/AST/ExprConstShared.h
@@ -62,5 +62,8 @@ GCCTypeClass EvaluateBuiltinClassifyType(QualType T,
 void HandleComplexComplexMul(llvm::APFloat A, llvm::APFloat B, llvm::APFloat C,
                              llvm::APFloat D, llvm::APFloat &ResR,
                              llvm::APFloat &ResI);
+void HandleComplexComplexDiv(llvm::APFloat A, llvm::APFloat B, llvm::APFloat C,
+                             llvm::APFloat D, llvm::APFloat &ResR,
+                             llvm::APFloat &ResI);
 
 #endif

diff  --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 3a6c8b4f82cca..9ac53438de8e4 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -15189,6 +15189,48 @@ void HandleComplexComplexMul(APFloat A, APFloat B, APFloat C, APFloat D,
   }
 }
 
+void HandleComplexComplexDiv(APFloat A, APFloat B, APFloat C, APFloat D,
+                             APFloat &ResR, APFloat &ResI) {
+  // This is an implementation of complex division according to the
+  // constraints laid out in C11 Annex G. The implementation uses the
+  // following naming scheme:
+  //   (a + ib) / (c + id)
+
+  int DenomLogB = 0;
+  APFloat MaxCD = maxnum(abs(C), abs(D));
+  if (MaxCD.isFinite()) {
+    DenomLogB = ilogb(MaxCD);
+    C = scalbn(C, -DenomLogB, APFloat::rmNearestTiesToEven);
+    D = scalbn(D, -DenomLogB, APFloat::rmNearestTiesToEven);
+  }
+  APFloat Denom = C * C + D * D;
+  ResR =
+      scalbn((A * C + B * D) / Denom, -DenomLogB, APFloat::rmNearestTiesToEven);
+  ResI =
+      scalbn((B * C - A * D) / Denom, -DenomLogB, APFloat::rmNearestTiesToEven);
+  if (ResR.isNaN() && ResI.isNaN()) {
+    if (Denom.isPosZero() && (!A.isNaN() || !B.isNaN())) {
+      ResR = APFloat::getInf(ResR.getSemantics(), C.isNegative()) * A;
+      ResI = APFloat::getInf(ResR.getSemantics(), C.isNegative()) * B;
+    } else if ((A.isInfinity() || B.isInfinity()) && C.isFinite() &&
+               D.isFinite()) {
+      A = APFloat::copySign(APFloat(A.getSemantics(), A.isInfinity() ? 1 : 0),
+                            A);
+      B = APFloat::copySign(APFloat(B.getSemantics(), B.isInfinity() ? 1 : 0),
+                            B);
+      ResR = APFloat::getInf(ResR.getSemantics()) * (A * C + B * D);
+      ResI = APFloat::getInf(ResI.getSemantics()) * (B * C - A * D);
+    } else if (MaxCD.isInfinity() && A.isFinite() && B.isFinite()) {
+      C = APFloat::copySign(APFloat(C.getSemantics(), C.isInfinity() ? 1 : 0),
+                            C);
+      D = APFloat::copySign(APFloat(D.getSemantics(), D.isInfinity() ? 1 : 0),
+                            D);
+      ResR = APFloat::getZero(ResR.getSemantics()) * (A * C + B * D);
+      ResI = APFloat::getZero(ResI.getSemantics()) * (B * C - A * D);
+    }
+  }
+}
+
 bool ComplexExprEvaluator::VisitBinaryOperator(const BinaryOperator *E) {
   if (E->isPtrMemOp() || E->isAssignmentOp() || E->getOpcode() == BO_Comma)
     return ExprEvaluatorBaseTy::VisitBinaryOperator(E);
@@ -15326,39 +15368,7 @@ bool ComplexExprEvaluator::VisitBinaryOperator(const BinaryOperator *E) {
           // No real optimizations we can do here, stub out with zero.
           B = APFloat::getZero(A.getSemantics());
         }
-        int DenomLogB = 0;
-        APFloat MaxCD = maxnum(abs(C), abs(D));
-        if (MaxCD.isFinite()) {
-          DenomLogB = ilogb(MaxCD);
-          C = scalbn(C, -DenomLogB, APFloat::rmNearestTiesToEven);
-          D = scalbn(D, -DenomLogB, APFloat::rmNearestTiesToEven);
-        }
-        APFloat Denom = C * C + D * D;
-        ResR = scalbn((A * C + B * D) / Denom, -DenomLogB,
-                      APFloat::rmNearestTiesToEven);
-        ResI = scalbn((B * C - A * D) / Denom, -DenomLogB,
-                      APFloat::rmNearestTiesToEven);
-        if (ResR.isNaN() && ResI.isNaN()) {
-          if (Denom.isPosZero() && (!A.isNaN() || !B.isNaN())) {
-            ResR = APFloat::getInf(ResR.getSemantics(), C.isNegative()) * A;
-            ResI = APFloat::getInf(ResR.getSemantics(), C.isNegative()) * B;
-          } else if ((A.isInfinity() || B.isInfinity()) && C.isFinite() &&
-                     D.isFinite()) {
-            A = APFloat::copySign(
-                APFloat(A.getSemantics(), A.isInfinity() ? 1 : 0), A);
-            B = APFloat::copySign(
-                APFloat(B.getSemantics(), B.isInfinity() ? 1 : 0), B);
-            ResR = APFloat::getInf(ResR.getSemantics()) * (A * C + B * D);
-            ResI = APFloat::getInf(ResI.getSemantics()) * (B * C - A * D);
-          } else if (MaxCD.isInfinity() && A.isFinite() && B.isFinite()) {
-            C = APFloat::copySign(
-                APFloat(C.getSemantics(), C.isInfinity() ? 1 : 0), C);
-            D = APFloat::copySign(
-                APFloat(D.getSemantics(), D.isInfinity() ? 1 : 0), D);
-            ResR = APFloat::getZero(ResR.getSemantics()) * (A * C + B * D);
-            ResI = APFloat::getZero(ResI.getSemantics()) * (B * C - A * D);
-          }
-        }
+        HandleComplexComplexDiv(A, B, C, D, ResR, ResI);
       }
     } else {
       if (RHS.getComplexIntReal() == 0 && RHS.getComplexIntImag() == 0)

diff  --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
index 0964ac046e630..d47f6f23c8bc1 100644
--- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -891,11 +891,14 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
   if (const auto *AT = RHSType->getAs<AtomicType>())
     RHSType = AT->getValueType();
 
+  bool LHSIsComplex = LHSType->isAnyComplexType();
+  unsigned LHSOffset;
+  bool RHSIsComplex = RHSType->isAnyComplexType();
+
   // For ComplexComplex Mul, we have special ops to make their implementation
   // easier.
   BinaryOperatorKind Op = E->getOpcode();
-  if (Op == BO_Mul && LHSType->isAnyComplexType() &&
-      RHSType->isAnyComplexType()) {
+  if (Op == BO_Mul && LHSIsComplex && RHSIsComplex) {
     assert(classifyPrim(LHSType->getAs<ComplexType>()->getElementType()) ==
            classifyPrim(RHSType->getAs<ComplexType>()->getElementType()));
     PrimType ElemT =
@@ -907,18 +910,51 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
     return this->emitMulc(ElemT, E);
   }
 
+  if (Op == BO_Div && RHSIsComplex) {
+    QualType ElemQT = RHSType->getAs<ComplexType>()->getElementType();
+    PrimType ElemT = classifyPrim(ElemQT);
+    // If the LHS is not complex, we still need to do the full complex
+    // division, so just create a complex value and stub it out with
+    // the LHS and a zero.
+
+    if (!LHSIsComplex) {
+      // This is using the RHS type for the fake-complex LHS.
+      if (auto LHSO = allocateLocal(RHS))
+        LHSOffset = *LHSO;
+      else
+        return false;
+
+      if (!this->emitGetPtrLocal(LHSOffset, E))
+        return false;
+
+      if (!this->visit(LHS))
+        return false;
+      // real is LHS
+      if (!this->emitInitElem(ElemT, 0, E))
+        return false;
+      // imag is zero
+      if (!this->visitZeroInitializer(ElemT, ElemQT, E))
+        return false;
+      if (!this->emitInitElem(ElemT, 1, E))
+        return false;
+    } else {
+      if (!this->visit(LHS))
+        return false;
+    }
+
+    if (!this->visit(RHS))
+      return false;
+    return this->emitDivc(ElemT, E);
+  }
+
   // Evaluate LHS and save value to LHSOffset.
-  bool LHSIsComplex;
-  unsigned LHSOffset;
   if (LHSType->isAnyComplexType()) {
-    LHSIsComplex = true;
     LHSOffset = this->allocateLocalPrimitive(LHS, PT_Ptr, true, false);
     if (!this->visit(LHS))
       return false;
     if (!this->emitSetLocal(PT_Ptr, LHSOffset, E))
       return false;
   } else {
-    LHSIsComplex = false;
     PrimType LHST = classifyPrim(LHSType);
     LHSOffset = this->allocateLocalPrimitive(LHS, LHST, true, false);
     if (!this->visit(LHS))
@@ -928,17 +964,14 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
   }
 
   // Same with RHS.
-  bool RHSIsComplex;
   unsigned RHSOffset;
   if (RHSType->isAnyComplexType()) {
-    RHSIsComplex = true;
     RHSOffset = this->allocateLocalPrimitive(RHS, PT_Ptr, true, false);
     if (!this->visit(RHS))
       return false;
     if (!this->emitSetLocal(PT_Ptr, RHSOffset, E))
       return false;
   } else {
-    RHSIsComplex = false;
     PrimType RHST = classifyPrim(RHSType);
     RHSOffset = this->allocateLocalPrimitive(RHS, RHST, true, false);
     if (!this->visit(RHS))
@@ -1018,6 +1051,22 @@ bool ByteCodeExprGen<Emitter>::VisitComplexBinOp(const BinaryOperator *E) {
           return false;
       }
       break;
+    case BO_Div:
+      assert(!RHSIsComplex);
+      if (!loadComplexValue(LHSIsComplex, false, ElemIndex, LHSOffset, LHS))
+        return false;
+
+      if (!loadComplexValue(RHSIsComplex, false, ElemIndex, RHSOffset, RHS))
+        return false;
+
+      if (ResultElemT == PT_Float) {
+        if (!this->emitDivf(getRoundingMode(E), E))
+          return false;
+      } else {
+        if (!this->emitDiv(ResultElemT, E))
+          return false;
+      }
+      break;
 
     default:
       return false;

diff  --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
index 44ff50b26133a..fea83de829261 100644
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -425,6 +425,78 @@ inline bool Mulc(InterpState &S, CodePtr OpPC) {
   return true;
 }
 
+template <PrimType Name, class T = typename PrimConv<Name>::T>
+inline bool Divc(InterpState &S, CodePtr OpPC) {
+  const Pointer &RHS = S.Stk.pop<Pointer>();
+  const Pointer &LHS = S.Stk.pop<Pointer>();
+  const Pointer &Result = S.Stk.peek<Pointer>();
+
+  if constexpr (std::is_same_v<T, Floating>) {
+    APFloat A = LHS.atIndex(0).deref<Floating>().getAPFloat();
+    APFloat B = LHS.atIndex(1).deref<Floating>().getAPFloat();
+    APFloat C = RHS.atIndex(0).deref<Floating>().getAPFloat();
+    APFloat D = RHS.atIndex(1).deref<Floating>().getAPFloat();
+
+    APFloat ResR(A.getSemantics());
+    APFloat ResI(A.getSemantics());
+    HandleComplexComplexDiv(A, B, C, D, ResR, ResI);
+
+    // Copy into the result.
+    Result.atIndex(0).deref<Floating>() = Floating(ResR);
+    Result.atIndex(0).initialize();
+    Result.atIndex(1).deref<Floating>() = Floating(ResI);
+    Result.atIndex(1).initialize();
+    Result.initialize();
+  } else {
+    // Integer element type.
+    const T &LHSR = LHS.atIndex(0).deref<T>();
+    const T &LHSI = LHS.atIndex(1).deref<T>();
+    const T &RHSR = RHS.atIndex(0).deref<T>();
+    const T &RHSI = RHS.atIndex(1).deref<T>();
+    unsigned Bits = LHSR.bitWidth();
+    const T Zero = T::from(0, Bits);
+
+    if (Compare(RHSR, Zero) == ComparisonCategoryResult::Equal &&
+        Compare(RHSI, Zero) == ComparisonCategoryResult::Equal) {
+      const SourceInfo &E = S.Current->getSource(OpPC);
+      S.FFDiag(E, diag::note_expr_divide_by_zero);
+      return false;
+    }
+
+    // Den = real(RHS)² + imag(RHS)²
+    T A, B;
+    if (T::mul(RHSR, RHSR, Bits, &A) || T::mul(RHSI, RHSI, Bits, &B))
+      return false;
+    T Den;
+    if (T::add(A, B, Bits, &Den))
+      return false;
+
+    // real(Result) = ((real(LHS) * real(RHS)) + (imag(LHS) * imag(RHS))) / Den
+    T &ResultR = Result.atIndex(0).deref<T>();
+    T &ResultI = Result.atIndex(1).deref<T>();
+
+    if (T::mul(LHSR, RHSR, Bits, &A) || T::mul(LHSI, RHSI, Bits, &B))
+      return false;
+    if (T::add(A, B, Bits, &ResultR))
+      return false;
+    if (T::div(ResultR, Den, Bits, &ResultR))
+      return false;
+    Result.atIndex(0).initialize();
+
+    // imag(Result) = ((imag(LHS) * real(RHS)) - (real(LHS) * imag(RHS))) / Den
+    if (T::mul(LHSI, RHSR, Bits, &A) || T::mul(LHSR, RHSI, Bits, &B))
+      return false;
+    if (T::sub(A, B, Bits, &ResultI))
+      return false;
+    if (T::div(ResultI, Den, Bits, &ResultI))
+      return false;
+    Result.atIndex(1).initialize();
+    Result.initialize();
+  }
+
+  return true;
+}
+
 /// 1) Pops the RHS from the stack.
 /// 2) Pops the LHS from the stack.
 /// 3) Pushes 'LHS & RHS' on the stack

diff  --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td
index aa07858fdb3c2..ddd955fc4cfa4 100644
--- a/clang/lib/AST/Interp/Opcodes.td
+++ b/clang/lib/AST/Interp/Opcodes.td
@@ -533,6 +533,10 @@ def Mulc : Opcode {
 def Rem  : IntegerOpcode;
 def Div  : IntegerOpcode;
 def Divf : FloatOpcode;
+def Divc : Opcode {
+  let Types = [NumberTypeClass];
+  let HasGroup = 1;
+}
 
 def BitAnd : IntegerOpcode;
 def BitOr : IntegerOpcode;

diff  --git a/clang/test/AST/Interp/complex.cpp b/clang/test/AST/Interp/complex.cpp
index 7c72087585b27..a969aadfdcd08 100644
--- a/clang/test/AST/Interp/complex.cpp
+++ b/clang/test/AST/Interp/complex.cpp
@@ -40,6 +40,21 @@ constexpr _Complex int IIMC = IIMA * IIMB;
 static_assert(__real(IIMC) == -30, "");
 static_assert(__imag(IIMC) == 40, "");
 
+static_assert(1.0j / 0.0 == 1); // both-error {{static assertion}} \
+                                // both-note {{division by zero}}
+static_assert(__builtin_isinf_sign(__real__((1.0 + 1.0j) / (0.0 + 0.0j))) == 1);
+static_assert(__builtin_isinf_sign(__real__((1.0 + 1.0j) / 0.0)) == 1); // both-error {{static assertion}} \
+                                                                        // both-note {{division by zero}}
+static_assert(__builtin_isinf_sign(__real__((__builtin_inf() + 1.0j) / (0.0 + 0.0j))) == 1);
+static_assert(__builtin_isinf_sign(__imag__((1.0 + InfC) / (0.0 + 0.0j))) == 1);
+static_assert(__builtin_isinf_sign(__imag__((InfInf) / (0.0 + 0.0j))) == 1);
+
+constexpr _Complex int IIDA = {10,20};
+constexpr _Complex int IIDB = {1,2};
+constexpr _Complex int IIDC = IIDA / IIDB;
+static_assert(__real(IIDC) == 10, "");
+static_assert(__imag(IIDC) == 0, "");
+
 constexpr _Complex int Comma1 = {1, 2};
 constexpr _Complex int Comma2 = (0, Comma1);
 static_assert(Comma1 == Comma1, "");

diff  --git a/clang/test/SemaCXX/complex-folding.cpp b/clang/test/SemaCXX/complex-folding.cpp
index 7bfd36f156ea6..f2fa93ad4795b 100644
--- a/clang/test/SemaCXX/complex-folding.cpp
+++ b/clang/test/SemaCXX/complex-folding.cpp
@@ -1,4 +1,5 @@
 // RUN: %clang_cc1 %s -std=c++1z -fsyntax-only -verify
+// RUN: %clang_cc1 %s -std=c++1z -fsyntax-only -verify -fexperimental-new-constant-interpreter
 //
 // Test the constant folding of builtin complex numbers.
 


        


More information about the cfe-commits mailing list