[Mlir-commits] [mlir] Fix complex power for large inputs. (PR #88387)

Johannes Reifferscheid llvmlistbot at llvm.org
Thu Apr 11 05:58:35 PDT 2024


https://github.com/jreiffers created https://github.com/llvm/llvm-project/pull/88387

For example, 1e30^1.2 currently overflows.

Also forward fastmath flags.

This ports XLA's logic and was verified with its test suite. Note that rsqrt and sqrt are still broken.

>From 146390a14313d313ce31b5e1389eede7bd6e6a1f Mon Sep 17 00:00:00 2001
From: Johannes Reifferscheid <jreiffers at google.com>
Date: Thu, 11 Apr 2024 14:47:19 +0200
Subject: [PATCH] Fix complex power for large inputs.

For example, 1e30^1.2 currently overflows.

Also forward fastmath flags.

This ports XLA's logic and was verified with its test suite.
---
 .../ComplexToStandard/ComplexToStandard.cpp   | 149 +++++++++++-------
 .../convert-to-standard.mlir                  |  21 ++-
 2 files changed, 115 insertions(+), 55 deletions(-)

diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 462036e51a1f1c..9c82e8105f06e5 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -989,65 +989,107 @@ struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
   }
 };
 
-/// Coverts x^y = (a+bi)^(c+di) to
+/// Converts lhs^rhs = (a+bi)^(c+di) to
///    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
///    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
-                                 ComplexType type, Value a, Value b, Value c,
-                                 Value d) {
+                                 ComplexType type, Value lhs, Value c, Value d,
+                                 arith::FastMathFlags fmf) {
   auto elementType = cast<FloatType>(type.getElementType());
 
-  // Compute (a*a+b*b)^(0.5c).
-  Value aaPbb = builder.create<arith::AddFOp>(
-      builder.create<arith::MulFOp>(a, a), builder.create<arith::MulFOp>(b, b));
-  Value half = builder.create<arith::ConstantOp>(
-      elementType, builder.getFloatAttr(elementType, 0.5));
-  Value halfC = builder.create<arith::MulFOp>(half, c);
-  Value aaPbbTohalfC = builder.create<math::PowFOp>(aaPbb, halfC);
-
-  // Compute exp(-d*atan2(b,a)).
-  Value negD = builder.create<arith::NegFOp>(d);
-  Value argX = builder.create<math::Atan2Op>(b, a);
-  Value negDArgX = builder.create<arith::MulFOp>(negD, argX);
-  Value eToNegDArgX = builder.create<math::ExpOp>(negDArgX);
-
-  // Compute (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)).
-  Value coeff = builder.create<arith::MulFOp>(aaPbbTohalfC, eToNegDArgX);
-
-  // Compute c*atan2(b,a)+0.5d*ln(a*a+b*b).
-  Value lnAaPbb = builder.create<math::LogOp>(aaPbb);
-  Value halfD = builder.create<arith::MulFOp>(half, d);
-  Value q = builder.create<arith::AddFOp>(
-      builder.create<arith::MulFOp>(c, argX),
-      builder.create<arith::MulFOp>(halfD, lnAaPbb));
-
-  Value cosQ = builder.create<math::CosOp>(q);
-  Value sinQ = builder.create<math::SinOp>(q);
+  Value a = builder.create<complex::ReOp>(lhs);
+  Value b = builder.create<complex::ImOp>(lhs);
+  // Use |lhs| instead of a*a+b*b, whose squares overflow for large inputs.
+  Value abs = builder.create<complex::AbsOp>(lhs, fmf);
+  Value absToC = builder.create<math::PowFOp>(abs, c, fmf);
+
+  Value negD = builder.create<arith::NegFOp>(d, fmf);
+  Value argLhs = builder.create<math::Atan2Op>(b, a, fmf);
+  Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf);
+  Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf);
+  // coeff = |lhs|^c * exp(-d*arg(lhs)), q = c*arg(lhs) + d*ln|lhs|.
+  Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
+  Value lnAbs = builder.create<math::LogOp>(abs, fmf);
+  Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf);
+  Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf);
+  Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
+  Value cosQ = builder.create<math::CosOp>(q, fmf);
+  Value sinQ = builder.create<math::SinOp>(q, fmf);
+  // Cases 0-4 below are chained selects; later cases take priority.
+  Value inf = builder.create<arith::ConstantOp>(
+      elementType,
+      builder.getFloatAttr(elementType,
+                           APFloat::getInf(elementType.getFloatSemantics())));
   Value zero = builder.create<arith::ConstantOp>(
-      elementType, builder.getFloatAttr(elementType, 0));
+      elementType, builder.getFloatAttr(elementType, 0.0));
   Value one = builder.create<arith::ConstantOp>(
-      elementType, builder.getFloatAttr(elementType, 1));
-
-  Value xEqZero =
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, aaPbb, zero);
-  Value yGeZero = builder.create<arith::AndIOp>(
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, c, zero),
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero));
-  Value cEqZero =
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero);
-  Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
+      elementType, builder.getFloatAttr(elementType, 1.0));
   Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
-  Value complexOther = builder.create<complex::CreateOp>(
-      type, builder.create<arith::MulFOp>(coeff, cosQ),
-      builder.create<arith::MulFOp>(coeff, sinQ));
+  Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
+  Value complexInf = builder.create<complex::CreateOp>(type, inf, zero);
 
-  // x^y is 0 if x is 0 and y > 0. 0^0 is defined to be 1.0, see
+  // Case 0:
+  // 0^(c + 0*i) is 0 for c > 0, and 0^0 is defined to be 1.0, see
+  // Branch Cuts for Complex Elementary Functions or Much Ado About
+  // Nothing's Sign Bit, W. Kahan, Section 10.
+  Value absEqZero =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf);
+  Value dEqZero =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
+  Value cEqZero =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
+  Value bEqZero =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
+
+  Value zeroLeC =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
+  Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf);
+  Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf);
+  Value complexOneOrZero =
+      builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero);
+  Value coeffCosSin =
+      builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
+  Value cutoff0 = builder.create<arith::SelectOp>(
+      builder.create<arith::AndIOp>(
+          builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
+      complexOneOrZero, coeffCosSin);
+
+  // Case 1:
+  // x^0 is defined to be 1 for any x, see
   // Branch Cuts for Complex Elementary Functions or Much Ado About
   // Nothing's Sign Bit, W. Kahan, Section 10.
-  return builder.create<arith::SelectOp>(
-      builder.create<arith::AndIOp>(xEqZero, yGeZero),
-      builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero),
-      complexOther);
+  Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero);
+  Value cutoff1 =
+      builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
+
+  // Case 2:
+  // 1^(c + d*i) = 1 + 0*i
+  Value lhsEqOne = builder.create<arith::AndIOp>(
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one),
+      bEqZero);
+  Value cutoff2 =
+      builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
+
+  // Case 3:
+  // (+inf + 0*i)^(c + 0*i) = inf + 0*i for c > 0
+  Value lhsEqInf = builder.create<arith::AndIOp>(
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf),
+      bEqZero);
+  Value rhsGt0 = builder.create<arith::AndIOp>(
+      dEqZero,
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero));
+  Value cutoff3 = builder.create<arith::SelectOp>(
+      builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
+
+  // Case 4:
+  // (+inf + 0*i)^(c + 0*i) = 0 + 0*i for c < 0
+  Value rhsLt0 = builder.create<arith::AndIOp>(
+      dEqZero,
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero));
+  Value cutoff4 = builder.create<arith::SelectOp>(
+      builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
+
+  return cutoff4;
 }
 
 struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
@@ -1060,12 +1102,11 @@ struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
     auto type = cast<ComplexType>(adaptor.getLhs().getType());
     auto elementType = cast<FloatType>(type.getElementType());
 
-    Value a = builder.create<complex::ReOp>(elementType, adaptor.getLhs());
-    Value b = builder.create<complex::ImOp>(elementType, adaptor.getLhs());
     Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
     Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
 
-    rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
+    rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
+                                                c, d, op.getFastmath())});
     return success();
   }
 };
@@ -1080,14 +1121,14 @@ struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
     auto type = cast<ComplexType>(adaptor.getComplex().getType());
     auto elementType = cast<FloatType>(type.getElementType());
 
-    Value a = builder.create<complex::ReOp>(elementType, adaptor.getComplex());
-    Value b = builder.create<complex::ImOp>(elementType, adaptor.getComplex());
     Value c = builder.create<arith::ConstantOp>(
         elementType, builder.getFloatAttr(elementType, -0.5));
     Value d = builder.create<arith::ConstantOp>(
         elementType, builder.getFloatAttr(elementType, 0));
 
-    rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
+    rewriter.replaceOp(op,
+                       {powOpConversionImpl(builder, type, adaptor.getComplex(),
+                                            c, d, op.getFastmath())});
     return success();
   }
 };
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index a1de61d10bb226..8d2fb09daa87b6 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -753,13 +753,32 @@ func.func @complex_conj(%arg: complex<f32>) -> complex<f32> {
 
 // -----
 
-// CHECK-LABEL:   func.func @complex_pow
+// CHECK-LABEL: func.func @complex_pow
+// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
 func.func @complex_pow(%lhs: complex<f32>,
                        %rhs: complex<f32>) -> complex<f32> {
   %pow = complex.pow %lhs, %rhs : complex<f32>
   return %pow : complex<f32>
 }
 
+// CHECK: %[[A:.*]] = complex.re %[[LHS]]
+// CHECK: %[[B:.*]] = complex.im %[[LHS]]
+// CHECK: math.atan2 %[[B]], %[[A]] : f32
+
+// -----
+
+// CHECK-LABEL: func.func @complex_pow_with_fmf
+// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
+func.func @complex_pow_with_fmf(%lhs: complex<f32>,
+                                %rhs: complex<f32>) -> complex<f32> {
+  %pow = complex.pow %lhs, %rhs fastmath<nnan,contract> : complex<f32>
+  return %pow : complex<f32>
+}
+
+// CHECK: %[[A:.*]] = complex.re %[[LHS]]
+// CHECK: %[[B:.*]] = complex.im %[[LHS]]
+// CHECK: math.atan2 %[[B]], %[[A]] fastmath<nnan,contract> : f32
+
 // -----
 
 // CHECK-LABEL:   func.func @complex_rsqrt



More information about the Mlir-commits mailing list