[Mlir-commits] [mlir] [mlir][complex] Fix exp accuracy (PR #164952)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 24 03:08:42 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Aleksei Nurmukhametov (nurmukhametov)

<details>
<summary>Changes</summary>

This ports the implementation from openxla/stablehlo#2682 by @pearu.

Three `complex.exp` test cases were added to `Integration/Dialect/Complex/CPU/correctness.mlir`. I also verified accuracy using XLA's `complex_unary_op_test` and its MLIR emitters.

---
Full diff: https://github.com/llvm/llvm-project/pull/164952.diff


3 Files Affected:

- (modified) mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp (+41-13) 
- (modified) mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir (+34-6) 
- (modified) mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir (+32) 


``````````diff
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 0fe72394b61d6..9e46b7d78baca 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -313,25 +313,53 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
   using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
 
+  // exp(x+I*y) = exp(x)*(cos(y)+I*sin(y))
+  // Handle special cases the same way the StableHLO implementation does:
+  // 1. When y == 0, set imag(exp(z)) = 0
+  // 2. When exp(x) == inf, use exp(x/2)*(cos(y)+I*sin(y))*exp(x/2)
   LogicalResult
   matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
-    auto elementType = cast<FloatType>(type.getElementType());
-    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
-
-    Value real =
-        complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
-    Value imag =
-        complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
-    Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue());
-    Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue());
+    auto ET = cast<FloatType>(type.getElementType());
+    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
+    const auto &floatSemantics = ET.getFloatSemantics();
+    ImplicitLocOpBuilder b(loc, rewriter);
+
+    Value x = complex::ReOp::create(b, ET, adaptor.getComplex());
+    Value y = complex::ImOp::create(b, ET, adaptor.getComplex());
+    Value zero = arith::ConstantOp::create(b, ET, b.getZeroAttr(ET));
+    Value half = arith::ConstantOp::create(b, ET, b.getFloatAttr(ET, 0.5));
+    Value inf = arith::ConstantOp::create(
+        b, ET, b.getFloatAttr(ET, APFloat::getInf(floatSemantics)));
+
+    Value exp = math::ExpOp::create(b, x, fmf);
+    Value xHalf = arith::MulFOp::create(b, x, half, fmf);
+    Value expHalf = math::ExpOp::create(b, xHalf, fmf);
+    Value cos = math::CosOp::create(b, y, fmf);
+    Value sin = math::SinOp::create(b, y, fmf);
+
+    Value expIsInf =
+        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, exp, inf, fmf);
+    Value yIsZero =
+        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, y, zero);
+
+    // Real path: select between exp(x)*cos(y) and exp(x/2)*cos(y)*exp(x/2)
+    Value realNormal = arith::MulFOp::create(b, exp, cos, fmf);
+    Value expHalfCos = arith::MulFOp::create(b, expHalf, cos, fmf);
+    Value realOverflow = arith::MulFOp::create(b, expHalfCos, expHalf, fmf);
     Value resultReal =
-        arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue());
-    Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue());
-    Value resultImag =
-        arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue());
+        arith::SelectOp::create(b, expIsInf, realOverflow, realNormal);
+
+    // Imaginary part: if y == 0 return 0 else select between exp(x)*sin(y) and
+    // exp(x/2)*sin(y)*exp(x/2)
+    Value imagNormal = arith::MulFOp::create(b, exp, sin, fmf);
+    Value expHalfSin = arith::MulFOp::create(b, expHalf, sin, fmf);
+    Value imagOverflow = arith::MulFOp::create(b, expHalfSin, expHalf, fmf);
+    Value imagNonZero =
+        arith::SelectOp::create(b, expIsInf, imagOverflow, imagNormal);
+    Value resultImag = arith::SelectOp::create(b, yIsZero, zero, imagNonZero);
 
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index dec62f92c7b2e..7a82236b0656e 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -211,11 +211,25 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
 }
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] : f32
+// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[INF:.*]] = arith.constant 0x7F800000 : f32
 // CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] : f32
-// CHECK-DAG: %[[RESULT_REAL:.]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] : f32
+// CHECK-DAG: %[[REAL_HALF:.*]] = arith.mulf %[[REAL]], %[[HALF]] : f32
+// CHECK-DAG: %[[EXP_HALF:.*]] = math.exp %[[REAL_HALF]] : f32
+// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] : f32
 // CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] : f32
-// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] : f32
+// CHECK-DAG: %[[IS_INF:.*]] = arith.cmpf oeq, %[[EXP_REAL]], %[[INF]] : f32
+// CHECK-DAG: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK-DAG: %[[REAL_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] : f32
+// CHECK-DAG: %[[EXP_HALF_COS:.*]] = arith.mulf %[[EXP_HALF]], %[[COS_IMAG]] : f32
+// CHECK-DAG: %[[REAL_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_COS]], %[[EXP_HALF]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[IS_INF]], %[[REAL_OVERFLOW]], %[[REAL_NORMAL]] : f32
+// CHECK-DAG: %[[IMAG_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] : f32
+// CHECK-DAG: %[[EXP_HALF_SIN:.*]] = arith.mulf %[[EXP_HALF]], %[[SIN_IMAG]] : f32
+// CHECK-DAG: %[[IMAG_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_SIN]], %[[EXP_HALF]] : f32
+// CHECK-DAG: %[[IMAG_NONZERO:.*]] = arith.select %[[IS_INF]], %[[IMAG_OVERFLOW]], %[[IMAG_NORMAL]] : f32
+// CHECK: %[[RESULT_IMAG:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[ZERO]], %[[IMAG_NONZERO]] : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
@@ -832,11 +846,25 @@ func.func @complex_exp_with_fmf(%arg: complex<f32>) -> complex<f32> {
 }
 // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[INF:.*]] = arith.constant 0x7F800000 : f32
 // CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK-DAG: %[[RESULT_REAL:.]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[REAL_HALF:.*]] = arith.mulf %[[REAL]], %[[HALF]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_HALF:.*]] = math.exp %[[REAL_HALF]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
 // CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[IS_INF:.*]] = arith.cmpf oeq, %[[EXP_REAL]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK-DAG: %[[REAL_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_HALF_COS:.*]] = arith.mulf %[[EXP_HALF]], %[[COS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[REAL_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_COS]], %[[EXP_HALF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[IS_INF]], %[[REAL_OVERFLOW]], %[[REAL_NORMAL]] : f32
+// CHECK-DAG: %[[IMAG_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_HALF_SIN:.*]] = arith.mulf %[[EXP_HALF]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[IMAG_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_SIN]], %[[EXP_HALF]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[IMAG_NONZERO:.*]] = arith.select %[[IS_INF]], %[[IMAG_OVERFLOW]], %[[IMAG_NORMAL]] : f32
+// CHECK: %[[RESULT_IMAG:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[ZERO]], %[[IMAG_NONZERO]] : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
index 1bcef0a0df316..ea587e92674d7 100644
--- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -49,6 +49,11 @@ func.func @conj(%arg: complex<f32>) -> complex<f32> {
   func.return %conj : complex<f32>
 }
 
+func.func @exp(%arg: complex<f32>) -> complex<f32> {
+  %exp = complex.exp %arg : complex<f32>
+  func.return %exp : complex<f32>
+}
+
 // %input contains pairs of lhs, rhs, i.e. [lhs_0, rhs_0, lhs_1, rhs_1,...]
 func.func @test_binary(%input: tensor<?xcomplex<f32>>,
                        %func: (complex<f32>, complex<f32>) -> complex<f32>) {
@@ -353,5 +358,32 @@ func.func @entry() {
   call @test_element_f64(%abs_test_cast, %abs_func)
     : (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> ()
 
+  // complex.exp test
+  %exp_test = arith.constant dense<[
+    (1.0, 2.0),
+    // CHECK:      -1.1312
+    // CHECK-NEXT:  2.4717
+
+    // The first case to consider is overflow of exp(real_part). If computed
+    // directly, this yields inf * 0 = NaN, which is incorrect.
+    (500.0, 0.0),
+    // CHECK-NEXT:  inf
+    // CHECK-NOT:   nan
+    // CHECK-NEXT:  0
+
+    // In this case, the overflow of exp(real_part) is compensated when
+    // sin(imag_part) is close to zero, yielding a finite imaginary part.
+    (90.0238094, 5.900613e-39)
+    // CHECK-NEXT:  inf
+    // CHECK-NOT:   inf
+    // CHECK-NEXT:  7.3746
+  ]> : tensor<3xcomplex<f32>>
+  %exp_test_cast = tensor.cast %exp_test
+    :  tensor<3xcomplex<f32>> to tensor<?xcomplex<f32>>
+
+  %exp_func = func.constant @exp : (complex<f32>) -> complex<f32>
+  call @test_unary(%exp_test_cast, %exp_func)
+    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()
+
   func.return
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/164952


More information about the Mlir-commits mailing list